mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-15 11:43:21 +00:00)

flux kontext

This commit is contained in:
parent 597d26b7e0
commit eb92f0c11c
13  defaults/ReadMe.txt  Normal file
@ -0,0 +1,13 @@
Please do not modify any file in this folder.

If you want to change a property of a default model, copy the corresponding model file into the ./finetunes folder and modify the properties you want to change in the new file.
If a property is not present in the new file, it is inherited automatically from the default file with the same name.

For instance, to hide a model:

{
    "model":
    {
        "visible": false
    }
}
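To make the inheritance rule above concrete, here is a minimal Python sketch of how a finetune file could be merged over its matching default. The file layout follows the ReadMe, but the merge helper itself is illustrative and is not the application's actual loader.

import json
from pathlib import Path

def load_model_settings(name: str) -> dict:
    # Hypothetical helper illustrating the override rule described above:
    # properties present in ./finetunes/<name>.json win, everything else
    # is inherited from ./defaults/<name>.json.
    with open(Path("defaults") / f"{name}.json") as f:
        settings = json.load(f)
    override_path = Path("finetunes") / f"{name}.json"
    if override_path.exists():
        with open(override_path) as f:
            override = json.load(f)
        # shallow merge of the "model" section, then of the top level
        merged_model = {**settings.get("model", {}), **override.get("model", {})}
        settings = {**settings, **override, "model": merged_model}
    return settings

# e.g. a finetunes/flux_dev_kontext.json containing {"model": {"visible": false}}
# would hide that model while keeping every other default property.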
@ -1,14 +1,14 @@
 {
     "model":
     {
-        "name": "First Last Frame to Video 720p (FLF2V)14B",
+        "name": "First Last Frame to Video 720p (FLF2V) 14B",
         "architecture" : "flf2v_720p",
-        "visible" : false,
+        "visible" : true,
         "description": "The First Last Frame 2 Video model is the official Image 2 Video model that supports Start and End frames.",
         "URLs": [
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_bf16.safetensors",
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors"
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_mbf16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mbf16_int8.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mfp16_int8.safetensors"
         ],
         "auto_quantize": true
     },
19  defaults/flux_dev_kontext.json  Normal file
@ -0,0 +1,19 @@
{
    "model": {
        "name": "Flux Dev Kontext 12B",
        "architecture": "flux_dev_kontext",
        "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that the output resolution is modified by Flux Kontext and may not be what you requested.",
        "URLs": [
            "c:/temp/kontext/flux1_kontext_dev_bf16.safetensors",
            "c:/temp/kontext/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ],
        "URLs2": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ]
    },
    "resolution": "1280x720",
    "video_length": "1"
}
13  defaults/fun_inp.json  Normal file
@ -0,0 +1,13 @@
{
    "model":
    {
        "name": "Fun InP image2video 14B",
        "architecture" : "fun_inp",
        "description": "The Fun model is an alternative image 2 video model that supports End Image fixing out of the box (unlike the original Wan image 2 video model).",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors"
        ]
    }
}
11  defaults/fun_inp_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "Fun InP image2video 1.3B",
        "architecture" : "fun_inp_1.3B",
        "description": "The Fun model is an alternative image 2 video model that supports End Image fixing out of the box (unlike the original Wan image 2 video model). This version also brings image 2 video capability to the 1.3B model.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_1.3B_bf16.safetensors"
        ]
    }
}
12  defaults/hunyuan.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video text2video 720p 13B",
        "architecture" : "hunyuan",
        "description": "Probably the best text 2 video model available.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_quanto_int8.safetensors"
        ]
    }
}
12  defaults/hunyuan_avatar.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video Avatar 720p 13B",
        "architecture" : "hunyuan_avatar",
        "description": "With the Hunyuan Video Avatar model you can animate a person based on the content of an audio input. Please note that the video generator processes 128-frame segments at a time (even if you ask for fewer). The good news is that it will concatenate multiple segments for long video generation (at most 3 segments are recommended, as quality degrades beyond that).",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors"
        ]
    }
}
12  defaults/hunyuan_custom.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video Custom 720p 13B",
        "architecture" : "hunyuan_custom",
        "description": "The Hunyuan Video Custom model is probably the best model for transferring people (only people for the moment), as it is quite good at preserving their identity. However it is slow: to get good results, you need to generate 720p videos with 30 steps.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_quanto_bf16_int8.safetensors"
        ]
    }
}
12  defaults/hunyuan_custom_audio.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video Custom Audio 720p 13B",
        "architecture" : "hunyuan_custom_audio",
        "description": "The Hunyuan Video Custom Audio model can be used to generate scenes of a person speaking given a Reference Image and a Recorded Voice or Song. The reference image is not a start image, so the person can be shown in a different context. The video length can be anything up to 10s. It is also quite good at generating silent videos of a person.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors"
        ]
    }
}
12  defaults/hunyuan_custom_edit.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video Custom Edit 720p 13B",
        "architecture" : "hunyuan_custom_edit",
        "description": "The Hunyuan Video Custom Edit model can be used for video inpainting on a person (adding accessories or completely replacing the person). In any case you will need to define a Video Mask that indicates which area of the video should be edited.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_quanto_bf16_int8.safetensors"
        ]
    }
}
12  defaults/hunyuan_i2v.json  Normal file
@ -0,0 +1,12 @@
{
    "model":
    {
        "name": "Hunyuan Video image2video 720p 13B",
        "architecture" : "hunyuan_i2v",
        "description": "A good-looking image 2 video model, but not as strong in prompt adherence.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_bf16v2.safetensors",
            "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_quanto_int8v2.safetensors"
        ]
    }
}
13  defaults/i2v.json  Normal file
@ -0,0 +1,13 @@
{
    "model":
    {
        "name": "Wan2.1 image2video 480p 14B",
        "architecture" : "i2v",
        "description": "The standard Wan Image 2 Video model, specialized in generating 480p videos. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors"
        ]
    }
}
14  defaults/i2v_720p.json  Normal file
@ -0,0 +1,14 @@
{
    "model":
    {
        "name": "Wan2.1 image2video 720p 14B",
        "architecture" : "i2v",
        "description": "The standard Wan Image 2 Video model, specialized in generating 720p videos. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors"
        ]
    },
    "resolution": "1280x720"
}
10  defaults/i2v_fusionix.json  Normal file
@ -0,0 +1,10 @@
{
    "model":
    {
        "name": "Wan2.1 image2video 480p FusioniX 14B",
        "architecture" : "i2v",
        "description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
        "URLs": "i2v",
        "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"]
    }
}
14  defaults/ltxv_13B.json  Normal file
@ -0,0 +1,14 @@
{
    "model":
    {
        "name": "LTX Video 0.9.7 13B",
        "architecture" : "ltxv_13B",
        "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). It is recommended to keep the number of steps at 30, or you will need to update the file 'ltx_video/configs/ltxv-13b-0.9.7-dev.yaml'. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors"
        ],
        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
    },
    "num_inference_steps": 30
}
14  defaults/ltxv_distilled.json  Normal file
@ -0,0 +1,14 @@
{
    "model":
    {
        "name": "LTX Video 0.9.7 Distilled 13B",
        "architecture" : "ltxv_13B",
        "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). This distilled version is very fast and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.",
        "URLs": "ltxv_13B",
        "loras": ["https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"],
        "loras_multipliers": [ 1 ],
        "lock_inference_steps": true,
        "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
    },
    "num_inference_steps": 6
}
11  defaults/phantom_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "Phantom 1.3B",
        "architecture" : "phantom_1.3B",
        "description": "The Phantom model is specialized in transferring people or objects of your choice into a generated video. It produces very nice results when used at 720p.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2_1_phantom_1.3B_mbf16.safetensors"
        ]
    }
}
13  defaults/phantom_14B.json  Normal file
@ -0,0 +1,13 @@
{
    "model":
    {
        "name": "Phantom 14B",
        "architecture" : "phantom_14B",
        "description": "The Phantom model is specialized in transferring people or objects of your choice into a generated video. It produces very nice results when used at 720p.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors"
        ]
    }
}
11  defaults/recam_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "ReCamMaster 1.3B",
        "architecture" : "recam_1.3B",
        "description": "The ReCam Master should in theory allow you to replay a video while applying a different camera movement. The model supports only videos that are at least 81 frames long (any frames beyond that will be ignored).",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_recammaster_1.3B_bf16.safetensors"
        ]
    }
}
11  defaults/sky_df_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "SkyReels2 Diffusion Forcing 1.3B",
        "architecture" : "sky_df_1.3B",
        "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors"
        ]
    }
}
13  defaults/sky_df_14B.json  Normal file
@ -0,0 +1,13 @@
{
    "model":
    {
        "name": "SkyReels2 Diffusion Forcing 540p 14B",
        "architecture" : "sky_df_14B",
        "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors"
        ]
    }
}
14  defaults/sky_df_720p_14B.json  Normal file
@ -0,0 +1,14 @@
{
    "model":
    {
        "name": "SkyReels2 Diffusion Forcing 720p 14B",
        "architecture" : "sky_df_14B",
        "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors"
        ]
    },
    "resolution": "1280x720"
}
13  defaults/t2v.json  Normal file
@ -0,0 +1,13 @@
{
    "model":
    {
        "name": "Wan2.1 text2video 14B",
        "architecture" : "t2v",
        "description": "The original Wan Text 2 Video model. Most other models have been built on top of it.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mfp16_int8.safetensors"
        ]
    }
}
11  defaults/t2v_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "Wan2.1 text2video 1.3B",
        "architecture" : "t2v_1.3B",
        "description": "The light version of the original Wan Text 2 Video model. Most other models have been built on top of it.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_bf16.safetensors"
        ]
    }
}
11  defaults/vace_1.3B.json  Normal file
@ -0,0 +1,11 @@
{
    "model":
    {
        "name": "Vace ControlNet 1.3B",
        "architecture" : "vace_1.3B",
        "description": "The Vace ControlNet model is a powerful model that lets you control the content of the generated video using additional custom data: pose or depth videos, images, or objects you want to see in the video.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1.3B_mbf16.safetensors"
        ]
    }
}
13  flux/__init__.py  Normal file
@ -0,0 +1,13 @@
try:
    from ._version import (
        version as __version__,  # type: ignore
        version_tuple,
    )
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")

from pathlib import Path

PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
18  flux/__main__.py  Normal file
@ -0,0 +1,18 @@
from fire import Fire

from .cli import main as cli_main
from .cli_control import main as control_main
from .cli_fill import main as fill_main
from .cli_kontext import main as kontext_main
from .cli_redux import main as redux_main

if __name__ == "__main__":
    Fire(
        {
            "t2i": cli_main,
            "control": control_main,
            "fill": fill_main,
            "kontext": kontext_main,
            "redux": redux_main,
        }
    )
21  flux/_version.py  Normal file
@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.0.post58+g1371b2b'
__version_tuple__ = version_tuple = (0, 0, 'post58', 'g1371b2b')
112  flux/flux_main.py  Normal file
@ -0,0 +1,112 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

from mmgp import offload as offload
import torch
from wan.utils.utils import calculate_new_dimensions
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.modules.layers import get_linear_split_map
from flux.util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        torch_device = "cpu"

        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = "flux-dev-kontext"
        self.model = load_flow_model(self.name, model_filename[0], torch_device)

        self.vae = load_ae(self.name, device=torch_device)

        # offload.change_dtype(self.model, dtype, True)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        sampling_steps: int = 20,
        input_ref_images = None,
        width = 832,
        height = 480,
        guide_scale: float = 2.5,
        fit_into_canvas = None,
        callback = None,
        loras_slists = None,
        frame_num = 1,
        **bbargs
    ):
        if self._interrupt:
            return None

        rng = torch.Generator(device="cuda")
        if seed is None:
            seed = rng.seed()

        if input_ref_images is not None and len(input_ref_images) > 0:
            image_ref = input_ref_images[0]
            w, h = image_ref.size
            height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

        inp, height, width = prepare_kontext(
            t5=self.t5,
            clip=self.clip,
            prompt=input_prompt,
            ae=self.vae,
            img_cond=image_ref,
            target_width=width,
            target_height=height,
            bs=frame_num,
            seed=seed,
            device="cuda",
        )

        inp.pop("img_cond_orig")
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

        def unpack_latent(x):
            return unpack(x.float(), height, width)

        # denoise initial noise
        x = denoise(self.model, **inp, timesteps=timesteps, guidance=guide_scale, callback=callback, pipeline=self, loras_slists=loras_slists, unpack_latent=unpack_latent)
        if x is None:
            return None
        # decode latents to pixel space
        x = unpack_latent(x)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            x = self.vae.decode(x)

        x = x.clamp(-1, 1)
        x = x.transpose(0, 1)
        return x
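For orientation, here is a minimal, hedged sketch of how this class could be driven on its own. The checkpoint and text-encoder paths are placeholders, a CUDA device and downloaded weights are required, and `_interrupt` is set manually because the file never initializes it (the surrounding app is expected to manage that flag).

# Hypothetical standalone usage of flux.flux_main.model_factory (paths are placeholders).
import torch
from PIL import Image
from flux.flux_main import model_factory

factory = model_factory(
    checkpoint_dir="ckpts",                                               # placeholder
    model_filename=["ckpts/flux1_kontext_dev_bf16.safetensors"],          # placeholder
    text_encoder_filename="ckpts/T5_xxl_enc_bf16.safetensors",            # placeholder
    dtype=torch.bfloat16,
)
factory._interrupt = False  # normally managed by the host application

frames = factory.generate(
    seed=42,
    input_prompt="replace the logo with the text 'Black Forest Labs'",
    sampling_steps=20,
    input_ref_images=[Image.open("input.png")],   # the image to edit
    width=832,
    height=480,
    guide_scale=2.5,
)
# `frames` is a tensor in [-1, 1]; a single frame, since frame_num defaults to 1.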
54  flux/math.py  Normal file
@ -0,0 +1,54 @@
import torch
from einops import rearrange
from torch import Tensor
from wan.modules.attention import pay_attention


def attention(qkv_list, pe: Tensor) -> Tensor:
    q, k, v = qkv_list
    qkv_list.clear()
    q_list = [q]
    q = None
    q = apply_rope_(q_list, pe)
    k_list = [k]
    k = None
    k = apply_rope_(k_list, pe)
    qkv_list = [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)]
    del q, k, v
    x = pay_attention(qkv_list).transpose(1, 2)
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope_(q_list, freqs_cis: Tensor) -> Tensor:
    # memory-friendly variant: consumes the single tensor held in q_list and returns the rotated tensor
    xq = q_list[0]
    xqshape = xq.shape
    xqdtype = xq.dtype
    q_list.clear()
    xq = xq.float().reshape(*xqshape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq[..., 0]
    xq = freqs_cis[..., 1] * xq[..., 1]

    xq_out.add_(xq)
    # xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]

    return xq_out.reshape(*xqshape).to(xqdtype)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
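For reference, the rotary embedding computed by rope() and consumed by apply_rope() above is the standard RoPE rotation; restated from the code (with D the per-axis dimension and theta the base):

% rope() stores, for each frequency index d, the 2x2 block
% [cos(p*omega_d), -sin(p*omega_d); sin(p*omega_d), cos(p*omega_d)],
% and apply_rope() rotates each consecutive feature pair (x_{2d}, x_{2d+1}):
\[
\begin{pmatrix} x'_{2d} \\ x'_{2d+1} \end{pmatrix}
=
\begin{pmatrix}
\cos(p\,\omega_d) & -\sin(p\,\omega_d) \\
\sin(p\,\omega_d) & \cos(p\,\omega_d)
\end{pmatrix}
\begin{pmatrix} x_{2d} \\ x_{2d+1} \end{pmatrix},
\qquad \omega_d = \theta^{-2d/D}.
\]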
168  flux/model.py  Normal file
@ -0,0 +1,168 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from flux.modules.lora import LinearLora, replace_linear_with_lora


@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def preprocess_loras(self, model_type, sd):
        # remap diffusers-style LoRA keys to this module's naming
        new_sd = {}
        if len(sd) == 0:
            return sd

        first_key = next(iter(sd))
        if first_key.startswith("transformer."):
            src_list = [".attn.to_q.", ".attn.to_k.", ".attn.to_v."]
            tgt_list = [".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v."]
            for k, v in sd.items():
                k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks")
                k = k.replace("transformer.double_transformer_blocks", "diffusion_model.double_blocks")
                for src, tgt in zip(src_list, tgt_list):
                    k = k.replace(src, tgt)

                new_sd[k] = v

        return new_sd

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        callback=None,
        pipeline=None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec += self.guidance_in(timestep_embedding(guidance, 256))
        vec += self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            if callback is not None:
                callback(-1, None, False, True)
            if pipeline._interrupt:
                return None
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img


class FluxLoraWrapper(Flux):
    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.lora_rank = lora_rank

        replace_linear_with_lora(
            self,
            max_rank=lora_rank,
            scale=lora_scale,
        )

    def set_lora_scale(self, scale: float) -> None:
        for module in self.modules():
            if isinstance(module, LinearLora):
                module.set_scale(scale=scale)
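As a sketch of how FluxParams is typically filled in: the numbers below are the publicly documented FLUX.1-dev hyper-parameters and are an assumption here; the actual values used for "flux-dev-kontext" come from the config table in flux/util.py, not from this file.

# Assumed hyper-parameters (taken from the public FLUX.1-dev configuration).
import torch
from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64,
    out_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,
)

# Build on the "meta" device (PyTorch >= 2.0) so no 12B-parameter weights are allocated.
with torch.device("meta"):
    model = Flux(params)  # pe_dim = 3072 / 24 = 128 = sum(axes_dim)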
320  flux/modules/autoencoder.py  Normal file
@ -0,0 +1,320 @@
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian(sample=sample_z)

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def get_VAE_tile_size(*args, **kwargs):
        return []

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
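To make the encode/decode scaling above concrete, here is a small hedged sketch that builds a toy-sized AutoEncoder and round-trips a random image. The channel counts are deliberately tiny placeholders, not the Flux VAE configuration (the real one, including its scale/shift factors, lives in flux/util.py); encode applies z = scale_factor * (z - shift_factor) and decode inverts it before the decoder runs.

import torch
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

# Toy parameters for illustration only.
params = AutoEncoderParams(
    resolution=64,
    in_channels=3,
    ch=32,
    out_ch=3,
    ch_mult=[1, 2],
    num_res_blocks=1,
    z_channels=4,
    scale_factor=1.0,
    shift_factor=0.0,
)
ae = AutoEncoder(params)

img = torch.rand(1, 3, 64, 64) * 2 - 1   # image in [-1, 1]
z = ae.encode(img)                       # latent, downsampled by 2**(len(ch_mult) - 1)
print(z.shape)                           # torch.Size([1, 4, 32, 32])
recon = ae.decode(z)
print(recon.shape)                       # torch.Size([1, 3, 64, 64])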
38  flux/modules/conditioner.py  Normal file
@ -0,0 +1,38 @@
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
import os


class HFEmbedder(nn.Module):
    def __init__(self, version: str, text_encoder_filename, max_length: int, is_clip = False, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            from mmgp import offload as offloadobj
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(os.path.dirname(text_encoder_filename), max_length=max_length)
            self.hf_module: T5EncoderModel = offloadobj.fast_load_transformers_model(text_encoder_filename)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key].bfloat16()
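A minimal sketch of the CLIP path of HFEmbedder (the T5 path depends on a local mmgp-packed checkpoint, so it is omitted). The model id below is the CLIP text encoder FLUX pipelines commonly pair with and is an assumption here; running this requires network access or a populated Hugging Face cache.

import torch
from flux.modules.conditioner import HFEmbedder

# Assumed model id; downloading it may take a while on first use.
clip = HFEmbedder("openai/clip-vit-large-patch14", None, max_length=77, is_clip=True)

with torch.no_grad():
    pooled = clip(["replace the logo with the text 'Black Forest Labs'"])
print(pooled.shape)  # torch.Size([1, 768]) -- the pooled CLIP embedding, cast to bfloat16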
99  flux/modules/image_embedders.py  Normal file
@ -0,0 +1,99 @@
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel

from flux.util import print_load_warning


class DepthImageEncoder:
    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        self.device = device
        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        hw = img.shape[-2:]

        img = torch.clamp(img, -1.0, 1.0)
        img_byte = ((img + 1.0) * 127.5).byte()

        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
        depth = self.depth_model(img.to(self.device)).predicted_depth
        depth = repeat(depth, "b h w -> b 3 h w")
        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)

        depth = depth / 127.5 - 1.0
        return depth


class CannyImageEncoder:
    def __init__(
        self,
        device,
        min_t: int = 50,
        max_t: int = 200,
    ):
        self.device = device
        self.min_t = min_t
        self.max_t = max_t

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        assert img.shape[0] == 1, "Only batch size 1 is supported"

        img = rearrange(img[0], "c h w -> h w c")
        img = torch.clamp(img, -1.0, 1.0)
        img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)

        # Apply Canny edge detection
        canny = cv2.Canny(img_np, self.min_t, self.max_t)

        # Convert back to torch tensor and reshape
        canny = torch.from_numpy(canny).float() / 127.5 - 1.0
        canny = rearrange(canny, "h w -> 1 1 h w")
        canny = repeat(canny, "b 1 ... -> b 3 ...")
        return canny.to(self.device)


class ReduxImageEncoder(nn.Module):
    siglip_model_name = "google/siglip-so400m-patch14-384"

    def __init__(
        self,
        device,
        redux_path: str,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        dtype=torch.bfloat16,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.dtype = dtype

        with self.device:
            self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
            self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)

            sd = load_sft(redux_path, device=str(device))
            missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
            print_load_warning(missing, unexpected)

        self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
        self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)

    def __call__(self, x: Image.Image) -> torch.Tensor:
        imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)

        _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state

        projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))

        return projected_x
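A quick, hedged illustration of the CannyImageEncoder contract shown above: the input is a single image tensor in [-1, 1] with batch size 1, and the output is a 3-channel edge map in the same range. It runs on CPU and needs no checkpoint download (only OpenCV).

import torch
from flux.modules.image_embedders import CannyImageEncoder

encoder = CannyImageEncoder(device="cpu")      # thresholds default to 50 / 200

img = torch.rand(1, 3, 128, 128) * 2 - 1       # fake image in [-1, 1]
edges = encoder(img)

print(edges.shape)                             # torch.Size([1, 3, 128, 128])
print(edges.min().item(), edges.max().item())  # entries are -1.0 (background) or 1.0 (edge)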
327
flux/modules/layers copy.py
Normal file
327
flux/modules/layers copy.py
Normal file
@ -0,0 +1,327 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from flux.math import attention, rope
|
||||
|
||||
def get_linear_split_map():
|
||||
hidden_size = 3072
|
||||
_modules_map = {
|
||||
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
|
||||
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
|
||||
}
|
||||
return split_linear_modules_map
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated.mul_(1 + img_mod1.scale)
|
||||
img_modulated.add_(img_mod1.shift)
|
||||
|
||||
shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) )
|
||||
img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2)
|
||||
img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2)
|
||||
img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2)
|
||||
del img_modulated
|
||||
|
||||
# img_qkv = self.img_attn.qkv(img_modulated)
|
||||
# img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated.mul_(1 + txt_mod1.scale)
|
||||
txt_modulated.add_(txt_mod1.shift)
|
||||
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
|
||||
shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) )
|
||||
txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2)
|
||||
txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2)
|
||||
txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2)
|
||||
del txt_modulated
|
||||
|
||||
# txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
qkv_list = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv_list, pe=pe)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img.addcmul_(self.img_attn.proj(img_attn), img_mod1.gate)
|
||||
img.addcmul_(self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift), img_mod2.gate)
|
||||
|
||||
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt blocks
|
||||
txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate)
|
||||
txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate)
|
||||
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = self.pre_norm(x)
|
||||
x_mod.mul_(1 + mod.scale)
|
||||
x_mod.add_(mod.shift)
|
||||
|
||||
##### More spaghetti VRAM optimizations done by DeepBeepMeep!
|
||||
# I am sure you are a nice person and as you copy this code, you will give me proper credits:
|
||||
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
|
||||
|
||||
# x_mod = (1 + mod.scale) * x + mod.shift
|
||||
|
||||
shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) )
|
||||
q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2)
|
||||
k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2)
|
||||
v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2)
|
||||
|
||||
# shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) )
|
||||
# txt_q = self.linear1_attn_q(txt_mod).view(*shape)
|
||||
# txt_k = self.linear1_attn_k(txt_mod).view(*shape)
|
||||
# txt_v = self.linear1_attn_v(txt_mod).view(*shape)
|
||||
|
||||
# qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
qkv_list = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv_list, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
|
||||
x_mod_shape = x_mod.shape
|
||||
x_mod = x_mod.view(-1, x_mod.shape[-1])
|
||||
chunk_size = int(x_mod_shape[1]/6)
|
||||
x_chunks = torch.split(x_mod, chunk_size)
|
||||
attn = attn.view(-1, attn.shape[-1])
|
||||
attn_chunks = torch.split(attn, chunk_size)
|
||||
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
|
||||
mlp_chunk = self.linear1_mlp(x_chunk)
|
||||
mlp_chunk = self.mlp_act(mlp_chunk)
|
||||
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
|
||||
del attn_chunk, mlp_chunk
|
||||
x_chunk[...] = self.linear2(attn_mlp_chunk)
|
||||
del attn_mlp_chunk
|
||||
x_mod = x_mod.view(x_mod_shape)
|
||||
x.addcmul_(x_mod, mod.gate)
|
||||
return x
|
||||
|
||||
# output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
# return x + mod.gate * output
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
328
flux/modules/layers.py
Normal file
328
flux/modules/layers.py
Normal file
@ -0,0 +1,328 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from flux.math import attention, rope
|
||||
|
||||
def get_linear_split_map():
|
||||
hidden_size = 3072
|
||||
split_linear_modules_map = {
|
||||
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
|
||||
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
|
||||
}
|
||||
return split_linear_modules_map
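# Sketch (assumption, not part of the original diff): the split map above is
# presumably consumed by the offload/patching utilities elsewhere in this repo
# to slice each fused Linear into the named sub-modules. The helper below only
# illustrates that interpretation: output rows are split according to
# "split_sizes" and assigned to the attributes listed in "mapped_modules".
def _split_fused_linear_sketch(block: nn.Module, fused_name: str) -> None:
    spec = get_linear_split_map()[fused_name]
    fused: nn.Linear = getattr(block, fused_name)
    offset = 0
    for sub_name, size in zip(spec["mapped_modules"], spec["split_sizes"]):
        sub = nn.Linear(fused.in_features, size, bias=fused.bias is not None)
        sub.weight.data.copy_(fused.weight.data[offset: offset + size])
        if fused.bias is not None:
            sub.bias.data.copy_(fused.bias.data[offset: offset + size])
        setattr(block, sub_name, sub)
        offset += size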
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
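# Minimal sanity-check sketch for the embedding above (illustrative only): a
# batch of fractional timesteps yields an (N, dim) tensor of cos/sin features.
def _timestep_embedding_sketch() -> Tensor:
    t = torch.tensor([0.0, 0.25, 1.0])
    emb = timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)
    return emb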
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Normalizes whichever of q or k is provided (the other is passed as None); v only supplies the target dtype/device.
if k is not None:
return self.key_norm(k).to(v)
else:
return self.query_norm(q).to(v)
|
||||
# q = self.query_norm(q)
|
||||
# k = self.key_norm(k)
|
||||
# return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
raise Exception("not implemented")
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
def split_mlp(mlp, x, divide=4):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
chunk_size = int(x_shape[1]/divide)
|
||||
x_chunks = torch.split(x, chunk_size)
|
||||
for i, x_chunk in enumerate(x_chunks):
|
||||
mlp_chunk = mlp[0](x_chunk)
|
||||
mlp_chunk = mlp[1](mlp_chunk)
|
||||
x_chunk[...] = mlp[2](mlp_chunk)
|
||||
return x.reshape(x_shape)
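# Why split_mlp exists: the 4x-wide intermediate activation of the MLP is only
# materialized for one chunk of tokens at a time, so its peak VRAM cost is
# roughly `divide` times smaller than a single full-sequence pass, while the
# result is numerically the same (the MLP is applied token-wise). Shapes below
# are assumptions for illustration; note the input is overwritten in place.
def _split_mlp_sketch() -> Tensor:
    mlp = nn.Sequential(
        nn.Linear(3072, 4 * 3072, bias=True),
        nn.GELU(approximate="tanh"),
        nn.Linear(4 * 3072, 3072, bias=True),
    )
    x = torch.randn(1, 4096, 3072)        # (batch, tokens, hidden)
    out = split_mlp(mlp, x, divide=4)     # processes 1024 tokens per chunk
    assert out.shape == (1, 4096, 3072)
    return out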
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
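# Shape sketch for Modulation (illustrative): `vec` is the pooled conditioning
# vector of the diffusion step, and each ModulationOut holds per-channel
# shift/scale/gate tensors that broadcast over the token dimension, as used by
# the blocks below.
def _modulation_sketch() -> None:
    mod = Modulation(3072, double=True)
    vec = torch.randn(2, 3072)               # (batch, hidden)
    mod1, mod2 = mod(vec)
    assert mod1.shift.shape == (2, 1, 3072)  # broadcasts over tokens
    assert mod2 is not None and mod2.gate.shape == (2, 1, 3072)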
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated.mul_(1 + img_mod1.scale)
|
||||
img_modulated.add_(img_mod1.shift)
|
||||
|
||||
shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) )
|
||||
img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2)
|
||||
img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2)
|
||||
img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2)
|
||||
del img_modulated
|
||||
|
||||
|
||||
img_q = self.img_attn.norm(img_q, None, img_v)
img_k = self.img_attn.norm(None, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated.mul_(1 + txt_mod1.scale)
|
||||
txt_modulated.add_(txt_mod1.shift)
|
||||
|
||||
shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) )
|
||||
txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2)
|
||||
txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2)
|
||||
txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2)
|
||||
del txt_modulated
|
||||
|
||||
|
||||
txt_q = self.txt_attn.norm(txt_q, None, txt_v)
|
||||
txt_k = self.txt_attn.norm(None, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
|
||||
qkv_list = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv_list, pe=pe)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img.addcmul_(self.img_attn.proj(img_attn), img_mod1.gate)
|
||||
mod_img = self.img_norm2(img)
|
||||
mod_img.mul_(1 + img_mod2.scale)
|
||||
mod_img.add_(img_mod2.shift)
|
||||
mod_img = split_mlp(self.img_mlp, mod_img)
|
||||
# mod_img = self.img_mlp(mod_img)
|
||||
img.addcmul_(mod_img, img_mod2.gate)
|
||||
mod_img = None
|
||||
|
||||
# calculate the txt blocks
|
||||
txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate)
|
||||
txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = self.pre_norm(x)
|
||||
x_mod.mul_(1 + mod.scale)
|
||||
x_mod.add_(mod.shift)
|
||||
|
||||
##### More spaghetti VRAM optimizations done by DeepBeepMeep!
|
||||
# I am sure you are a nice person and as you copy this code, you will give me proper credits:
|
||||
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
|
||||
|
||||
# x_mod = (1 + mod.scale) * x + mod.shift
|
||||
|
||||
shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) )
|
||||
q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2)
|
||||
k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2)
|
||||
v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2)
|
||||
|
||||
q = self.norm(q, None, v)
|
||||
k = self.norm(None, k, v)
|
||||
|
||||
# compute attention
|
||||
qkv_list = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv_list, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
|
||||
x_mod_shape = x_mod.shape
|
||||
x_mod = x_mod.view(-1, x_mod.shape[-1])
|
||||
chunk_size = int(x_mod_shape[1]/6)
|
||||
x_chunks = torch.split(x_mod, chunk_size)
|
||||
attn = attn.view(-1, attn.shape[-1])
|
||||
attn_chunks = torch.split(attn, chunk_size)
|
||||
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
|
||||
mlp_chunk = self.linear1_mlp(x_chunk)
|
||||
mlp_chunk = self.mlp_act(mlp_chunk)
|
||||
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
|
||||
del attn_chunk, mlp_chunk
|
||||
x_chunk[...] = self.linear2(attn_mlp_chunk)
|
||||
del attn_mlp_chunk
|
||||
x_mod = x_mod.view(x_mod_shape)
|
||||
x.addcmul_(x_mod, mod.gate)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
94
flux/modules/lora.py
Normal file
94
flux/modules/lora.py
Normal file
@ -0,0 +1,94 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def replace_linear_with_lora(
|
||||
module: nn.Module,
|
||||
max_rank: int,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
new_lora = LinearLora(
|
||||
in_features=child.in_features,
|
||||
out_features=child.out_features,
|
||||
bias=child.bias,
|
||||
rank=max_rank,
|
||||
scale=scale,
|
||||
dtype=child.weight.dtype,
|
||||
device=child.weight.device,
|
||||
)
|
||||
|
||||
new_lora.weight = child.weight
|
||||
new_lora.bias = child.bias if child.bias is not None else None
|
||||
|
||||
setattr(module, name, new_lora)
|
||||
else:
|
||||
replace_linear_with_lora(
|
||||
module=child,
|
||||
max_rank=max_rank,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
|
||||
class LinearLora(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool,
|
||||
rank: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
lora_bias: bool = True,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias is not None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
assert isinstance(scale, float), "scale must be a float"
|
||||
|
||||
self.scale = scale
|
||||
self.rank = rank
|
||||
self.lora_bias = lora_bias
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
if rank > (new_rank := min(self.out_features, self.in_features)):
|
||||
self.rank = new_rank
|
||||
|
||||
self.lora_A = nn.Linear(
|
||||
in_features=in_features,
|
||||
out_features=self.rank,
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.lora_B = nn.Linear(
|
||||
in_features=self.rank,
|
||||
out_features=out_features,
|
||||
bias=self.lora_bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
assert isinstance(scale, float), "scalar value must be a float"
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
base_out = super().forward(input)
|
||||
|
||||
_lora_out_B = self.lora_B(self.lora_A(input))
|
||||
lora_update = _lora_out_B * self.scale
|
||||
|
||||
return base_out + lora_update
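# Usage sketch for the two helpers above (structural illustration only): every
# nn.Linear in a module is swapped for a LinearLora of rank max_rank. The
# lora_A/lora_B layers keep their default random init here, so the wrapped
# network's outputs differ from the original; a real setup would load trained
# adapter weights into them.
def _lora_sketch() -> None:
    net = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
    replace_linear_with_lora(net, max_rank=8, scale=1.0)
    assert isinstance(net[0], LinearLora) and net[0].rank == 8
    out = net(torch.randn(4, 64))
    assert out.shape == (4, 64)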
|
||||
392
flux/sampling.py
Normal file
392
flux/sampling.py
Normal file
@ -0,0 +1,392 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from .model import Flux
|
||||
from .modules.autoencoder import AutoEncoder
|
||||
from .modules.conditioner import HFEmbedder
|
||||
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
|
||||
from .util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device="cuda").manual_seed(seed),
|
||||
).to(device)
|
||||
|
||||
|
||||
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
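# Worked example of the packing performed by prepare() (shapes only, values are
# arbitrary): a (1, 16, 64, 64) latent, i.e. a 512x512 image after the 8x VAE
# downsample, becomes (1, 32*32, 64) tokens, and img_ids stores the
# (0, row, column) coordinates that EmbedND later converts to rotary phases.
def _prepare_shapes_sketch() -> None:
    lat = torch.randn(1, 16, 64, 64)
    tokens = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    assert tokens.shape == (1, 1024, 64)
    ids = torch.zeros(32, 32, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(32)[:, None]
    ids[..., 2] = ids[..., 2] + torch.arange(32)[None, :]
    assert repeat(ids, "h w c -> b (h w) c", b=1).shape == (1, 1024, 3)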
|
||||
|
||||
|
||||
def prepare_control(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
img: Tensor,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
encoder: DepthImageEncoder | CannyImageEncoder,
|
||||
img_cond_path: str,
|
||||
) -> dict[str, Tensor]:
|
||||
# load and encode the conditioning image
|
||||
bs, _, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img_cond = Image.open(img_cond_path).convert("RGB")
|
||||
|
||||
width = w * 8
|
||||
height = h * 8
|
||||
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
with torch.no_grad():
|
||||
img_cond = encoder(img_cond)
|
||||
img_cond = ae.encode(img_cond)
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
return_dict["img_cond"] = img_cond
|
||||
return return_dict
|
||||
|
||||
|
||||
def prepare_fill(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
img: Tensor,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
img_cond_path: str,
|
||||
mask_path: str,
|
||||
) -> dict[str, Tensor]:
|
||||
# load and encode the conditioning image and the mask
|
||||
bs, _, _, _ = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img_cond = Image.open(img_cond_path).convert("RGB")
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
mask = Image.open(mask_path).convert("L")
|
||||
mask = np.array(mask)
|
||||
mask = torch.from_numpy(mask).float() / 255.0
|
||||
mask = rearrange(mask, "h w -> 1 1 h w")
|
||||
|
||||
with torch.no_grad():
|
||||
img_cond = img_cond.to(img.device)
|
||||
mask = mask.to(img.device)
|
||||
img_cond = img_cond * (1 - mask)
|
||||
img_cond = ae.encode(img_cond)
|
||||
mask = mask[:, 0, :, :]
|
||||
mask = mask.to(torch.bfloat16)
|
||||
mask = rearrange(
|
||||
mask,
|
||||
"b (h ph) (w pw) -> b (ph pw) h w",
|
||||
ph=8,
|
||||
pw=8,
|
||||
)
|
||||
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if mask.shape[0] == 1 and bs > 1:
|
||||
mask = repeat(mask, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_cond = torch.cat((img_cond, mask), dim=-1)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
return_dict["img_cond"] = img_cond.to(img.device)
|
||||
return return_dict
|
||||
|
||||
|
||||
def prepare_redux(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
img: Tensor,
|
||||
prompt: str | list[str],
|
||||
encoder: ReduxImageEncoder,
|
||||
img_cond_path: str,
|
||||
) -> dict[str, Tensor]:
|
||||
bs, _, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img_cond = Image.open(img_cond_path).convert("RGB")
|
||||
with torch.no_grad():
|
||||
img_cond = encoder(img_cond)
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def prepare_kontext(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
img_cond: str,
|
||||
seed: int,
|
||||
device: torch.device,
|
||||
target_width: int | None = None,
|
||||
target_height: int | None = None,
|
||||
bs: int = 1,
|
||||
) -> tuple[dict[str, Tensor], int, int]:
|
||||
# load and encode the conditioning image
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
width, height = img_cond.size
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Kontext is trained on specific resolutions, using one of them is recommended
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
|
||||
width = 2 * int(width / 16)
|
||||
height = 2 * int(height / 16)
|
||||
|
||||
img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS)
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
img_cond_orig = img_cond.clone()
|
||||
|
||||
with torch.no_grad():
|
||||
img_cond = ae.encode(img_cond.to(device))
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# image ids are the same as base image with the first dimension set to 1
|
||||
# instead of 0
|
||||
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
img_cond_ids[..., 0] = 1
|
||||
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if target_width is None:
|
||||
target_width = 8 * width
|
||||
if target_height is None:
|
||||
target_height = 8 * height
|
||||
|
||||
img = get_noise(
|
||||
bs,
|
||||
target_height,
|
||||
target_width,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
return_dict["img_cond_seq"] = img_cond
|
||||
return_dict["img_cond_seq_ids"] = img_cond_ids.to(device)
|
||||
return_dict["img_cond_orig"] = img_cond_orig
|
||||
return return_dict, target_height, target_width
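# Sketch of the resolution snapping done above (illustrative): the conditioning
# image's aspect ratio is matched against PREFERED_KONTEXT_RESOLUTIONS (defined
# in flux.util, not reproduced here) and the winner is rounded down to the
# 16-pixel grid, which is why the output resolution can differ from the one the
# user requested.
def _kontext_resolution_sketch(image_width: int, image_height: int) -> tuple[int, int]:
    aspect_ratio = image_width / image_height
    _, width, height = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS
    )
    width, height = 2 * int(width / 16), 2 * int(height / 16)
    return 8 * width, 8 * height   # pixel resolution actually used for generation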
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(
|
||||
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
||||
) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
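# Illustration of the shift (numbers are assumptions): longer image sequences
# get a larger mu, which pushes the schedule toward t=1 so more steps are spent
# at high noise levels.
def _schedule_sketch() -> None:
    short_seq = get_schedule(4, image_seq_len=256, shift=True)
    long_seq = get_schedule(4, image_seq_len=4096, shift=True)
    assert len(short_seq) == len(long_seq) == 5   # num_steps + 1 boundaries
    assert long_seq[1] > short_seq[1]             # schedule skewed toward high noise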
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
# extra img tokens (channel-wise)
|
||||
img_cond: Tensor | None = None,
|
||||
# extra img tokens (sequence-wise)
|
||||
img_cond_seq: Tensor | None = None,
|
||||
img_cond_seq_ids: Tensor | None = None,
|
||||
callback=None,
|
||||
pipeline=None,
|
||||
loras_slists=None,
|
||||
unpack_latent=None,
|
||||
):
|
||||
|
||||
kwargs = {'pipeline': pipeline, 'callback': callback}
if callback is not None:
callback(-1, None, True)

updated_num_steps = len(timesteps) - 1
if callback is not None:
from wgp import update_loras_slists
update_loras_slists(model, loras_slists, updated_num_steps)
callback(-1, None, True, override_num_inference_steps=updated_num_steps)
|
||||
from mmgp import offload
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
|
||||
offload.set_step_no_for_lora(model, i)
|
||||
if pipeline._interrupt:
|
||||
return None
|
||||
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
img_input = img
|
||||
img_input_ids = img_ids
|
||||
if img_cond is not None:
|
||||
img_input = torch.cat((img, img_cond), dim=-1)
|
||||
if img_cond_seq is not None:
|
||||
assert (
|
||||
img_cond_seq_ids is not None
|
||||
), "You need to provide either both or neither of the sequence conditioning"
|
||||
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||
pred = model(
|
||||
img=img_input,
|
||||
img_ids=img_input_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
**kwargs
|
||||
)
|
||||
if pred is None: return None
|
||||
|
||||
if img_input_ids is not None:
|
||||
pred = pred[:, : img.shape[1]]
|
||||
|
||||
img += (t_prev - t_curr) * pred
|
||||
if callback is not None:
|
||||
preview = unpack_latent(img).transpose(0,1)
|
||||
callback(i, preview, False)
|
||||
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
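# unpack() is the inverse of the patchify rearrange used in prepare(); the
# round trip below (with assumed sizes for a 768x1280 output) restores the
# latent exactly, ready for the autoencoder decode.
def _unpack_roundtrip_sketch() -> None:
    latent = torch.randn(1, 16, 96, 160)   # 768x1280 image after 8x VAE downsample
    packed = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    restored = unpack(packed, height=768, width=1280)
    assert torch.equal(restored, latent)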
|
||||
302
flux/to_remove/cli.py
Normal file
302
flux/to_remove/cli.py
Normal file
@ -0,0 +1,302 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
|
||||
import torch
|
||||
from fire import Fire
|
||||
from transformers import pipeline
|
||||
|
||||
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
||||
from flux.util import (
|
||||
check_onnx_access_for_trt,
|
||||
configs,
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
save_image,
|
||||
)
|
||||
|
||||
NSFW_THRESHOLD = 0.85
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingOptions:
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
num_steps: int
|
||||
guidance: float
|
||||
seed: int | None
|
||||
|
||||
|
||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
||||
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the prompt or write a command starting with a slash:\n"
|
||||
"- '/w <width>' will set the width of the generated image\n"
|
||||
"- '/h <height>' will set the height of the generated image\n"
|
||||
"- '/s <seed>' sets the next seed\n"
|
||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
||||
"- '/n <steps>' sets the number of steps\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/w"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, width = prompt.split()
|
||||
options.width = 16 * (int(width) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/h"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, height = prompt.split()
|
||||
options.height = 16 * (int(height) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/g"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, guidance = prompt.split()
|
||||
options.guidance = float(guidance)
|
||||
print(f"Setting guidance to {options.guidance}")
|
||||
elif prompt.startswith("/s"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, seed = prompt.split()
|
||||
options.seed = int(seed)
|
||||
print(f"Setting seed to {options.seed}")
|
||||
elif prompt.startswith("/n"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, steps = prompt.split()
|
||||
options.num_steps = int(steps)
|
||||
print(f"Setting number of steps to {options.num_steps}")
|
||||
elif prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
if prompt != "":
|
||||
options.prompt = prompt
|
||||
return options
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
name: str = "flux-schnell",
|
||||
width: int = 1360,
|
||||
height: int = 768,
|
||||
seed: int | None = None,
|
||||
prompt: str = (
|
||||
"a photo of a forest with mist swirling around the tree trunks. The word "
|
||||
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
||||
),
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
num_steps: int | None = None,
|
||||
loop: bool = False,
|
||||
guidance: float = 2.5,
|
||||
offload: bool = False,
|
||||
output_dir: str = "output",
|
||||
add_sampling_metadata: bool = True,
|
||||
trt: bool = False,
|
||||
trt_transformer_precision: str = "bf16",
|
||||
track_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
||||
single image.
|
||||
|
||||
Args:
|
||||
name: Name of the model to load
|
||||
height: height of the sample in pixels (should be a multiple of 16)
|
||||
width: width of the sample in pixels (should be a multiple of 16)
|
||||
seed: Set a seed for sampling
|
||||
output_name: where to save the output image, `{idx}` will be replaced
|
||||
by the index of the sample
|
||||
prompt: Prompt used for sampling
|
||||
device: Pytorch device
|
||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
||||
loop: start an interactive session and sample multiple times
|
||||
guidance: guidance value used for guidance distillation
|
||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
||||
trt: use TensorRT backend for optimized inference
|
||||
trt_transformer_precision: specify transformer precision for inference
|
||||
track_usage: track usage of the model for licensing purposes
|
||||
"""
|
||||
|
||||
prompt = prompt.split("|")
|
||||
if len(prompt) == 1:
|
||||
prompt = prompt[0]
|
||||
additional_prompts = None
|
||||
else:
|
||||
additional_prompts = prompt[1:]
|
||||
prompt = prompt[0]
|
||||
|
||||
assert not (
|
||||
(additional_prompts is not None) and loop
|
||||
), "Do not provide additional prompts and set loop to True"
|
||||
|
||||
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
||||
|
||||
if name not in configs:
|
||||
available = ", ".join(configs.keys())
|
||||
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
||||
|
||||
torch_device = torch.device(device)
|
||||
if num_steps is None:
|
||||
num_steps = 4 if name == "flux-schnell" else 50
|
||||
|
||||
# allow for packing and conversion to latent space
|
||||
height = 16 * (height // 16)
|
||||
width = 16 * (width // 16)
|
||||
|
||||
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
idx = 0
|
||||
else:
|
||||
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
|
||||
if len(fns) > 0:
|
||||
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
if not trt:
|
||||
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
|
||||
clip = load_clip(torch_device)
|
||||
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
||||
ae = load_ae(name, device="cpu" if offload else torch_device)
|
||||
else:
|
||||
# lazy import to make install optional
|
||||
from flux.trt.trt_manager import ModuleName, TRTManager
|
||||
|
||||
# Check if we need ONNX model access (which requires authentication for FLUX models)
|
||||
onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
|
||||
|
||||
trt_ctx_manager = TRTManager(
|
||||
trt_transformer_precision=trt_transformer_precision,
|
||||
trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"),
|
||||
)
|
||||
engines = trt_ctx_manager.load_engines(
|
||||
model_name=name,
|
||||
module_names={
|
||||
ModuleName.CLIP,
|
||||
ModuleName.TRANSFORMER,
|
||||
ModuleName.T5,
|
||||
ModuleName.VAE,
|
||||
},
|
||||
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
|
||||
custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
|
||||
trt_image_height=height,
|
||||
trt_image_width=width,
|
||||
trt_batch_size=1,
|
||||
trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
|
||||
trt_static_batch=False,
|
||||
trt_static_shape=False,
|
||||
)
|
||||
|
||||
ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
|
||||
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
|
||||
clip = engines[ModuleName.CLIP].to(torch_device)
|
||||
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
|
||||
|
||||
rng = torch.Generator(device="cpu")
|
||||
opts = SamplingOptions(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_steps=num_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if loop:
|
||||
opts = parse_prompt(opts)
|
||||
|
||||
while opts is not None:
|
||||
if opts.seed is None:
|
||||
opts.seed = rng.seed()
|
||||
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# prepare input
|
||||
x = get_noise(
|
||||
1,
|
||||
opts.height,
|
||||
opts.width,
|
||||
device=torch_device,
|
||||
dtype=torch.bfloat16,
|
||||
seed=opts.seed,
|
||||
)
|
||||
opts.seed = None
|
||||
if offload:
|
||||
ae = ae.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
t5, clip = t5.to(torch_device), clip.to(torch_device)
|
||||
inp = prepare(t5, clip, x, prompt=opts.prompt)
|
||||
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
||||
|
||||
# offload TEs to CPU, load model to gpu
|
||||
if offload:
|
||||
t5, clip = t5.cpu(), clip.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
||||
|
||||
# offload model, load autoencoder to gpu
|
||||
if offload:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
ae.decoder.to(x.device)
|
||||
|
||||
# decode latents to pixel space
|
||||
x = unpack(x.float(), opts.height, opts.width)
|
||||
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
||||
x = ae.decode(x)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
fn = output_name.format(idx=idx)
|
||||
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
|
||||
|
||||
idx = save_image(
|
||||
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
|
||||
)
|
||||
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
elif additional_prompts:
|
||||
next_prompt = additional_prompts.pop(0)
|
||||
opts.prompt = next_prompt
|
||||
else:
|
||||
opts = None
|
||||
|
||||
if trt:
|
||||
trt_ctx_manager.stop_runtime()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
390
flux/to_remove/cli_control.py
Normal file
390
flux/to_remove/cli_control.py
Normal file
@ -0,0 +1,390 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
|
||||
import torch
|
||||
from fire import Fire
|
||||
from transformers import pipeline
|
||||
|
||||
from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
|
||||
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
|
||||
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingOptions:
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
num_steps: int
|
||||
guidance: float
|
||||
seed: int | None
|
||||
img_cond_path: str
|
||||
lora_scale: float | None
|
||||
|
||||
|
||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
||||
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the prompt or write a command starting with a slash:\n"
|
||||
"- '/w <width>' will set the width of the generated image\n"
|
||||
"- '/h <height>' will set the height of the generated image\n"
|
||||
"- '/s <seed>' sets the next seed\n"
|
||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
||||
"- '/n <steps>' sets the number of steps\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/w"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, width = prompt.split()
|
||||
options.width = 16 * (int(width) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/h"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, height = prompt.split()
|
||||
options.height = 16 * (int(height) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/g"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, guidance = prompt.split()
|
||||
options.guidance = float(guidance)
|
||||
print(f"Setting guidance to {options.guidance}")
|
||||
elif prompt.startswith("/s"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, seed = prompt.split()
|
||||
options.seed = int(seed)
|
||||
print(f"Setting seed to {options.seed}")
|
||||
elif prompt.startswith("/n"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, steps = prompt.split()
|
||||
options.num_steps = int(steps)
|
||||
print(f"Setting number of steps to {options.num_steps}")
|
||||
elif prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
if prompt != "":
|
||||
options.prompt = prompt
|
||||
return options
|
||||
|
||||
|
||||
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
|
||||
if options is None:
|
||||
return None
|
||||
|
||||
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the conditioning image or write a command starting with a slash:\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while True:
|
||||
img_cond_path = input(user_question)
|
||||
|
||||
if img_cond_path.startswith("/"):
|
||||
if img_cond_path.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not img_cond_path.startswith("/h"):
|
||||
print(f"Got invalid command '{img_cond_path}'\n{usage}")
|
||||
print(usage)
|
||||
continue
|
||||
|
||||
if img_cond_path == "":
|
||||
break
|
||||
|
||||
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".webp")
|
||||
):
|
||||
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
|
||||
continue
|
||||
|
||||
options.img_cond_path = img_cond_path
|
||||
break
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
|
||||
changed = False
|
||||
|
||||
if options is None:
|
||||
return None, changed
|
||||
|
||||
user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the lora scale or write a command starting with a slash:\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None, changed
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
if prompt != "":
|
||||
options.lora_scale = float(prompt)
|
||||
changed = True
|
||||
return options, changed
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
name: str,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
seed: int | None = None,
|
||||
prompt: str = "a robot made out of gold",
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
num_steps: int = 50,
|
||||
loop: bool = False,
|
||||
guidance: float | None = None,
|
||||
offload: bool = False,
|
||||
output_dir: str = "output",
|
||||
add_sampling_metadata: bool = True,
|
||||
img_cond_path: str = "assets/robot.webp",
|
||||
lora_scale: float | None = 0.85,
|
||||
trt: bool = False,
|
||||
trt_transformer_precision: str = "bf16",
|
||||
track_usage: bool = False,
|
||||
**kwargs: dict | None,
|
||||
):
|
||||
"""
|
||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
||||
single image.
|
||||
|
||||
Args:
|
||||
height: height of the sample in pixels (should be a multiple of 16)
|
||||
width: width of the sample in pixels (should be a multiple of 16)
|
||||
seed: Set a seed for sampling
|
||||
output_name: where to save the output image, `{idx}` will be replaced
|
||||
by the index of the sample
|
||||
prompt: Prompt used for sampling
|
||||
device: Pytorch device
|
||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
||||
loop: start an interactive session and sample multiple times
|
||||
guidance: guidance value used for guidance distillation
|
||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
||||
trt: use TensorRT backend for optimized inference
|
||||
trt_transformer_precision: specify transformer precision for inference
|
||||
track_usage: track usage of the model for licensing purposes
|
||||
"""
|
||||
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
||||
|
||||
if "lora" in name:
|
||||
assert not trt, "TRT does not support LORA"
|
||||
assert name in [
|
||||
"flux-dev-canny",
|
||||
"flux-dev-depth",
|
||||
"flux-dev-canny-lora",
|
||||
"flux-dev-depth-lora",
|
||||
], f"Got unknown model name: {name}"
|
||||
|
||||
if guidance is None:
|
||||
if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
|
||||
guidance = 30.0
|
||||
elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
|
||||
guidance = 10.0
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if name not in configs:
|
||||
available = ", ".join(configs.keys())
|
||||
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
||||
|
||||
torch_device = torch.device(device)
|
||||
|
||||
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
idx = 0
|
||||
else:
|
||||
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
|
||||
if len(fns) > 0:
|
||||
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
|
||||
img_embedder = DepthImageEncoder(torch_device)
|
||||
elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
|
||||
img_embedder = CannyImageEncoder(torch_device)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not trt:
|
||||
# init all components
|
||||
t5 = load_t5(torch_device, max_length=512)
|
||||
clip = load_clip(torch_device)
|
||||
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
||||
ae = load_ae(name, device="cpu" if offload else torch_device)
|
||||
else:
|
||||
# lazy import to make install optional
|
||||
from flux.trt.trt_manager import ModuleName, TRTManager
|
||||
|
||||
trt_ctx_manager = TRTManager(
|
||||
trt_transformer_precision=trt_transformer_precision,
|
||||
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
|
||||
)
|
||||
|
||||
engines = trt_ctx_manager.load_engines(
|
||||
model_name=name,
|
||||
module_names={
|
||||
ModuleName.CLIP,
|
||||
ModuleName.TRANSFORMER,
|
||||
ModuleName.T5,
|
||||
ModuleName.VAE,
|
||||
ModuleName.VAE_ENCODER,
|
||||
},
|
||||
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
|
||||
custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""),
|
||||
trt_image_height=height,
|
||||
trt_image_width=width,
|
||||
trt_batch_size=1,
|
||||
trt_static_batch=kwargs.get("static_batch", True),
|
||||
trt_static_shape=kwargs.get("static_shape", True),
|
||||
)
|
||||
|
||||
ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
|
||||
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
|
||||
clip = engines[ModuleName.CLIP].to(torch_device)
|
||||
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
|
||||
|
||||
# set lora scale
|
||||
if "lora" in name and lora_scale is not None:
|
||||
for _, module in model.named_modules():
|
||||
if hasattr(module, "set_scale"):
|
||||
module.set_scale(lora_scale)
|
||||
|
||||
rng = torch.Generator(device="cpu")
|
||||
opts = SamplingOptions(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_steps=num_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
img_cond_path=img_cond_path,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
if loop:
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
if "lora" in name:
|
||||
opts, changed = parse_lora_scale(opts)
|
||||
if changed:
|
||||
# update the lora scale:
|
||||
for _, module in model.named_modules():
|
||||
if hasattr(module, "set_scale"):
|
||||
module.set_scale(opts.lora_scale)
|
||||
|
||||
while opts is not None:
|
||||
if opts.seed is None:
|
||||
opts.seed = rng.seed()
|
||||
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# prepare input
|
||||
x = get_noise(
|
||||
1,
|
||||
opts.height,
|
||||
opts.width,
|
||||
device=torch_device,
|
||||
dtype=torch.bfloat16,
|
||||
seed=opts.seed,
|
||||
)
|
||||
opts.seed = None
|
||||
if offload:
|
||||
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
|
||||
inp = prepare_control(
|
||||
t5,
|
||||
clip,
|
||||
x,
|
||||
prompt=opts.prompt,
|
||||
ae=ae,
|
||||
encoder=img_embedder,
|
||||
img_cond_path=opts.img_cond_path,
|
||||
)
|
||||
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
||||
|
||||
# offload TEs and AE to CPU, load model to gpu
|
||||
if offload:
|
||||
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
||||
|
||||
# offload model, load autoencoder to gpu
|
||||
if offload:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
ae.decoder.to(x.device)
|
||||
|
||||
# decode latents to pixel space
|
||||
x = unpack(x.float(), opts.height, opts.width)
|
||||
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
||||
x = ae.decode(x)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
print(f"Done in {t1 - t0:.1f}s")
|
||||
|
||||
idx = save_image(
|
||||
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
|
||||
)
|
||||
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
if "lora" in name:
|
||||
opts, changed = parse_lora_scale(opts)
|
||||
if changed:
|
||||
# update the lora scale:
|
||||
for _, module in model.named_modules():
|
||||
if hasattr(module, "set_scale"):
|
||||
module.set_scale(opts.lora_scale)
|
||||
else:
|
||||
opts = None
|
||||
|
||||
if trt:
|
||||
trt_ctx_manager.stop_runtime()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
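Every CLI in this commit repeats the same --offload choreography around denoising; a minimal sketch of that pattern, with the encode/denoise/decode steps abstracted into placeholder callables (all names below are illustrative, not part of the diff):

import torch

def run_with_offload(t5, clip, ae, model, torch_device, encode_fn, denoise_fn, decode_fn):
    # 1) text encoders + AE on the GPU only while the conditioning is prepared
    t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
    inp = encode_fn(t5, clip, ae)
    # 2) free them and move the flow model in for denoising
    t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
    torch.cuda.empty_cache()
    model = model.to(torch_device)
    x = denoise_fn(model, inp)
    # 3) swap again so only the AE decoder occupies VRAM during decoding
    model = model.cpu()
    torch.cuda.empty_cache()
    ae.decoder.to(x.device)
    return decode_fn(ae, x)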
334
flux/to_remove/cli_fill.py
Normal file
@ -0,0 +1,334 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
|
||||
import torch
|
||||
from fire import Fire
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
|
||||
from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
|
||||
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingOptions:
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
num_steps: int
|
||||
guidance: float
|
||||
seed: int | None
|
||||
img_cond_path: str
|
||||
img_mask_path: str
|
||||
|
||||
|
||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
||||
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the prompt or write a command starting with a slash:\n"
|
||||
"- '/s <seed>' sets the next seed\n"
|
||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
||||
"- '/n <steps>' sets the number of steps\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/g"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, guidance = prompt.split()
|
||||
options.guidance = float(guidance)
|
||||
print(f"Setting guidance to {options.guidance}")
|
||||
elif prompt.startswith("/s"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, seed = prompt.split()
|
||||
options.seed = int(seed)
|
||||
print(f"Setting seed to {options.seed}")
|
||||
elif prompt.startswith("/n"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, steps = prompt.split()
|
||||
options.num_steps = int(steps)
|
||||
print(f"Setting number of steps to {options.num_steps}")
|
||||
elif prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
if prompt != "":
|
||||
options.prompt = prompt
|
||||
return options
|
||||
|
||||
|
||||
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
|
||||
if options is None:
|
||||
return None
|
||||
|
||||
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the conditioning image or write a command starting with a slash:\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while True:
|
||||
img_cond_path = input(user_question)
|
||||
|
||||
if img_cond_path.startswith("/"):
|
||||
if img_cond_path.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not img_cond_path.startswith("/h"):
|
||||
print(f"Got invalid command '{img_cond_path}'\n{usage}")
|
||||
print(usage)
|
||||
continue
|
||||
|
||||
if img_cond_path == "":
|
||||
break
|
||||
|
||||
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".webp")
|
||||
):
|
||||
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
|
||||
continue
|
||||
else:
|
||||
with Image.open(img_cond_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
if width % 32 != 0 or height % 32 != 0:
|
||||
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
|
||||
continue
|
||||
|
||||
options.img_cond_path = img_cond_path
|
||||
break
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
|
||||
if options is None:
|
||||
return None
|
||||
|
||||
user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the conditioning mask or write a command starting with a slash:\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while True:
|
||||
img_mask_path = input(user_question)
|
||||
|
||||
if img_mask_path.startswith("/"):
|
||||
if img_mask_path.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not img_mask_path.startswith("/h"):
|
||||
print(f"Got invalid command '{img_mask_path}'\n{usage}")
|
||||
print(usage)
|
||||
continue
|
||||
|
||||
if img_mask_path == "":
|
||||
break
|
||||
|
||||
if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".webp")
|
||||
):
|
||||
print(f"File '{img_mask_path}' does not exist or is not a valid image file")
|
||||
continue
|
||||
else:
|
||||
with Image.open(img_mask_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
if width % 32 != 0 or height % 32 != 0:
|
||||
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
|
||||
continue
|
||||
else:
|
||||
with Image.open(options.img_cond_path) as img_cond:
|
||||
img_cond_width, img_cond_height = img_cond.size
|
||||
|
||||
if width != img_cond_width or height != img_cond_height:
|
||||
print(
|
||||
f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
|
||||
)
|
||||
continue
|
||||
|
||||
options.img_mask_path = img_mask_path
|
||||
break
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
seed: int | None = None,
|
||||
prompt: str = "a white paper cup",
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
num_steps: int = 50,
|
||||
loop: bool = False,
|
||||
guidance: float = 30.0,
|
||||
offload: bool = False,
|
||||
output_dir: str = "output",
|
||||
add_sampling_metadata: bool = True,
|
||||
img_cond_path: str = "assets/cup.png",
|
||||
img_mask_path: str = "assets/cup_mask.png",
|
||||
track_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
||||
single image. This demo assumes that the conditioning image and mask have
|
||||
the same shape and that height and width are divisible by 32.
|
||||
|
||||
Args:
|
||||
seed: Set a seed for sampling
|
||||
output_dir: where to save the output images
|
||||
prompt: Prompt used for sampling
|
||||
device: Pytorch device
|
||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
||||
loop: start an interactive session and sample multiple times
|
||||
guidance: guidance value used for guidance distillation
|
||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
||||
img_mask_path: path to conditioning mask (jpeg/png/webp)
|
||||
track_usage: track usage of the model for licensing purposes
|
||||
"""
|
||||
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
||||
|
||||
name = "flux-dev-fill"
|
||||
if name not in configs:
|
||||
available = ", ".join(configs.keys())
|
||||
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
||||
|
||||
torch_device = torch.device(device)
|
||||
|
||||
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
idx = 0
|
||||
else:
|
||||
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
|
||||
if len(fns) > 0:
|
||||
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
# init all components
|
||||
t5 = load_t5(torch_device, max_length=128)
|
||||
clip = load_clip(torch_device)
|
||||
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
||||
ae = load_ae(name, device="cpu" if offload else torch_device)
|
||||
|
||||
rng = torch.Generator(device="cpu")
|
||||
with Image.open(img_cond_path) as img:
|
||||
width, height = img.size
|
||||
opts = SamplingOptions(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_steps=num_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
img_cond_path=img_cond_path,
|
||||
img_mask_path=img_mask_path,
|
||||
)
|
||||
|
||||
if loop:
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
|
||||
with Image.open(opts.img_cond_path) as img:
|
||||
width, height = img.size
|
||||
opts.height = height
|
||||
opts.width = width
|
||||
|
||||
opts = parse_img_mask_path(opts)
|
||||
|
||||
while opts is not None:
|
||||
if opts.seed is None:
|
||||
opts.seed = rng.seed()
|
||||
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# prepare input
|
||||
x = get_noise(
|
||||
1,
|
||||
opts.height,
|
||||
opts.width,
|
||||
device=torch_device,
|
||||
dtype=torch.bfloat16,
|
||||
seed=opts.seed,
|
||||
)
|
||||
opts.seed = None
|
||||
if offload:
|
||||
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
|
||||
inp = prepare_fill(
|
||||
t5,
|
||||
clip,
|
||||
x,
|
||||
prompt=opts.prompt,
|
||||
ae=ae,
|
||||
img_cond_path=opts.img_cond_path,
|
||||
mask_path=opts.img_mask_path,
|
||||
)
|
||||
|
||||
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
||||
|
||||
# offload TEs and AE to CPU, load model to gpu
|
||||
if offload:
|
||||
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
||||
|
||||
# offload model, load autoencoder to gpu
|
||||
if offload:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
ae.decoder.to(x.device)
|
||||
|
||||
# decode latents to pixel space
|
||||
x = unpack(x.float(), opts.height, opts.width)
|
||||
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
||||
x = ae.decode(x)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
print(f"Done in {t1 - t0:.1f}s")
|
||||
|
||||
idx = save_image(
|
||||
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
|
||||
)
|
||||
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
|
||||
with Image.open(opts.img_cond_path) as img:
|
||||
width, height = img.size
|
||||
opts.height = height
|
||||
opts.width = width
|
||||
|
||||
opts = parse_img_mask_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
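Because the entry point is wrapped with Fire(main), the keyword arguments of main() double as CLI flags; a hedged programmatic equivalent (the import path is assumed from the file location, the argument names and defaults come from the signature above):

from flux.to_remove.cli_fill import main  # assumed import path

main(
    prompt="a white paper cup",
    img_cond_path="assets/cup.png",
    img_mask_path="assets/cup_mask.png",
    num_steps=50,
    guidance=30.0,
    offload=True,
)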
368
flux/to_remove/cli_kontext.py
Normal file
@ -0,0 +1,368 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
|
||||
import torch
|
||||
from fire import Fire
|
||||
|
||||
from flux.content_filters import PixtralContentFilter
|
||||
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
|
||||
from flux.util import (
|
||||
aspect_ratio_to_height_width,
|
||||
check_onnx_access_for_trt,
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
save_image,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingOptions:
|
||||
prompt: str
|
||||
width: int | None
|
||||
height: int | None
|
||||
num_steps: int
|
||||
guidance: float
|
||||
seed: int | None
|
||||
img_cond_path: str
|
||||
|
||||
|
||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
||||
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the prompt or write a command starting with a slash:\n"
|
||||
"- '/ar <width>:<height>' will set the aspect ratio of the generated image\n"
|
||||
"- '/s <seed>' sets the next seed\n"
|
||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
||||
"- '/n <steps>' sets the number of steps\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/ar"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, ratio_prompt = prompt.split()
|
||||
if ratio_prompt == "auto":
|
||||
options.width = None
|
||||
options.height = None
|
||||
print("Setting resolution to input image resolution.")
|
||||
else:
|
||||
options.width, options.height = aspect_ratio_to_height_width(ratio_prompt)
|
||||
print(f"Setting resolution to {options.width} x {options.height}.")
|
||||
elif prompt.startswith("/h"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, height = prompt.split()
|
||||
if height == "auto":
|
||||
options.height = None
|
||||
else:
|
||||
options.height = 16 * (int(height) // 16)
|
||||
if options.height is not None and options.width is not None:
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
else:
|
||||
print(f"Setting resolution to {options.width} x {options.height}.")
|
||||
elif prompt.startswith("/g"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, guidance = prompt.split()
|
||||
options.guidance = float(guidance)
|
||||
print(f"Setting guidance to {options.guidance}")
|
||||
elif prompt.startswith("/s"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, seed = prompt.split()
|
||||
options.seed = int(seed)
|
||||
print(f"Setting seed to {options.seed}")
|
||||
elif prompt.startswith("/n"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, steps = prompt.split()
|
||||
options.num_steps = int(steps)
|
||||
print(f"Setting number of steps to {options.num_steps}")
|
||||
elif prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
if prompt != "":
|
||||
options.prompt = prompt
|
||||
return options
|
||||
|
||||
|
||||
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
|
||||
if options is None:
|
||||
return None
|
||||
|
||||
user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write a path to an image directly, leave this field empty "
|
||||
"to repeat the last input image or write a command starting with a slash:\n"
|
||||
"- '/q' to quit\n\n"
|
||||
"The input image will be edited by FLUX.1 Kontext creating a new image based"
|
||||
"on your instruction prompt."
|
||||
)
|
||||
|
||||
while True:
|
||||
img_cond_path = input(user_question)
|
||||
|
||||
if img_cond_path.startswith("/"):
|
||||
if img_cond_path.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not img_cond_path.startswith("/h"):
|
||||
print(f"Got invalid command '{img_cond_path}'\n{usage}")
|
||||
print(usage)
|
||||
continue
|
||||
|
||||
if img_cond_path == "":
|
||||
break
|
||||
|
||||
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".webp")
|
||||
):
|
||||
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
|
||||
continue
|
||||
|
||||
options.img_cond_path = img_cond_path
|
||||
break
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
name: str = "flux-dev-kontext",
|
||||
aspect_ratio: str | None = None,
|
||||
seed: int | None = None,
|
||||
prompt: str = "replace the logo with the text 'Black Forest Labs'",
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
num_steps: int = 30,
|
||||
loop: bool = False,
|
||||
guidance: float = 2.5,
|
||||
offload: bool = False,
|
||||
output_dir: str = "output",
|
||||
add_sampling_metadata: bool = True,
|
||||
img_cond_path: str = "assets/cup.png",
|
||||
trt: bool = False,
|
||||
trt_transformer_precision: str = "bf16",
|
||||
track_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
||||
single image.
|
||||
|
||||
Args:
|
||||
aspect_ratio: aspect ratio of the sample as '<width>:<height>', None
defaults to the resolution of the conditioning image
seed: Set a seed for sampling
output_dir: where to save the output images
|
||||
prompt: Prompt used for sampling
|
||||
device: Pytorch device
|
||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
||||
loop: start an interactive session and sample multiple times
|
||||
guidance: guidance value used for guidance distillation
|
||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
||||
trt: use TensorRT backend for optimized inference
|
||||
track_usage: track usage of the model for licensing purposes
|
||||
"""
|
||||
assert name == "flux-dev-kontext", f"Got unknown model name: {name}"
|
||||
|
||||
torch_device = torch.device(device)
|
||||
|
||||
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
idx = 0
|
||||
else:
|
||||
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
|
||||
if len(fns) > 0:
|
||||
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
if aspect_ratio is None:
|
||||
width = None
|
||||
height = None
|
||||
else:
|
||||
width, height = aspect_ratio_to_height_width(aspect_ratio)
|
||||
|
||||
if not trt:
|
||||
t5 = load_t5(torch_device, max_length=512)
|
||||
clip = load_clip(torch_device)
|
||||
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
||||
else:
|
||||
# lazy import to make install optional
|
||||
from flux.trt.trt_manager import ModuleName, TRTManager
|
||||
|
||||
# Check if we need ONNX model access (which requires authentication for FLUX models)
|
||||
onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
|
||||
|
||||
trt_ctx_manager = TRTManager(
|
||||
trt_transformer_precision=trt_transformer_precision,
|
||||
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
|
||||
)
|
||||
engines = trt_ctx_manager.load_engines(
|
||||
model_name=name,
|
||||
module_names={
|
||||
ModuleName.CLIP,
|
||||
ModuleName.TRANSFORMER,
|
||||
ModuleName.T5,
|
||||
},
|
||||
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
|
||||
custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
|
||||
trt_image_height=height,
|
||||
trt_image_width=width,
|
||||
trt_batch_size=1,
|
||||
trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
|
||||
trt_static_batch=False,
|
||||
trt_static_shape=False,
|
||||
)
|
||||
|
||||
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
|
||||
clip = engines[ModuleName.CLIP].to(torch_device)
|
||||
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
|
||||
|
||||
ae = load_ae(name, device="cpu" if offload else torch_device)
|
||||
content_filter = PixtralContentFilter(torch.device("cpu"))
|
||||
|
||||
rng = torch.Generator(device="cpu")
|
||||
opts = SamplingOptions(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_steps=num_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
img_cond_path=img_cond_path,
|
||||
)
|
||||
|
||||
if loop:
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
|
||||
while opts is not None:
|
||||
if opts.seed is None:
|
||||
opts.seed = rng.seed()
|
||||
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if content_filter.test_txt(opts.prompt):
|
||||
print("Your prompt has been automatically flagged. Please choose another prompt.")
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
continue
|
||||
if content_filter.test_image(opts.img_cond_path):
|
||||
print("Your input image has been automatically flagged. Please choose another image.")
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
continue
|
||||
|
||||
if offload:
|
||||
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
|
||||
inp, height, width = prepare_kontext(
|
||||
t5=t5,
|
||||
clip=clip,
|
||||
prompt=opts.prompt,
|
||||
ae=ae,
|
||||
img_cond_path=opts.img_cond_path,
|
||||
target_width=opts.width,
|
||||
target_height=opts.height,
|
||||
bs=1,
|
||||
seed=opts.seed,
|
||||
device=torch_device,
|
||||
)
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft")
|
||||
inp.pop("img_cond_orig")
|
||||
opts.seed = None
|
||||
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
||||
|
||||
# offload TEs and AE to CPU, load model to gpu
|
||||
if offload:
|
||||
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
|
||||
# denoise initial noise
|
||||
t00 = time.time()
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
||||
torch.cuda.synchronize()
|
||||
t01 = time.time()
|
||||
print(f"Denoising took {t01 - t00:.3f}s")
|
||||
|
||||
# offload model, load autoencoder to gpu
|
||||
if offload:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
ae.decoder.to(x.device)
|
||||
|
||||
# decode latents to pixel space
|
||||
x = unpack(x.float(), height, width)
|
||||
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
||||
ae_dev_t0 = time.perf_counter()
|
||||
x = ae.decode(x)
|
||||
torch.cuda.synchronize()
|
||||
ae_dev_t1 = time.perf_counter()
|
||||
print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s")
|
||||
|
||||
if content_filter.test_image(x.cpu()):
|
||||
print(
|
||||
"Your output image has been automatically flagged. Choose another prompt/image or try again."
|
||||
)
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
continue
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
print(f"Done in {t1 - t0:.1f}s")
|
||||
|
||||
idx = save_image(
|
||||
None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
|
||||
)
|
||||
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
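Each of these scripts resumes its output numbering with the same iglob/regex block; a condensed, hedged equivalent of that logic (the helper name is mine, not part of the diff):

import os
import re
from glob import iglob

def next_index(output_dir: str) -> int:
    # resume numbering after the highest existing img_<n>.jpg, or start at 0
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
    return max((int(fn.split("_")[-1].split(".")[0]) for fn in fns), default=-1) + 1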
290
flux/to_remove/cli_redux.py
Normal file
@ -0,0 +1,290 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from glob import iglob
|
||||
|
||||
import torch
|
||||
from fire import Fire
|
||||
from transformers import pipeline
|
||||
|
||||
from flux.modules.image_embedders import ReduxImageEncoder
|
||||
from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
|
||||
from flux.util import (
|
||||
get_checkpoint_path,
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
save_image,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingOptions:
|
||||
prompt: str
|
||||
width: int
|
||||
height: int
|
||||
num_steps: int
|
||||
guidance: float
|
||||
seed: int | None
|
||||
img_cond_path: str
|
||||
|
||||
|
||||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
||||
user_question = "Write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Leave this field empty to do nothing "
|
||||
"or write a command starting with a slash:\n"
|
||||
"- '/w <width>' will set the width of the generated image\n"
|
||||
"- '/h <height>' will set the height of the generated image\n"
|
||||
"- '/s <seed>' sets the next seed\n"
|
||||
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
||||
"- '/n <steps>' sets the number of steps\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while (prompt := input(user_question)).startswith("/"):
|
||||
if prompt.startswith("/w"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, width = prompt.split()
|
||||
options.width = 16 * (int(width) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/h"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, height = prompt.split()
|
||||
options.height = 16 * (int(height) // 16)
|
||||
print(
|
||||
f"Setting resolution to {options.width} x {options.height} "
|
||||
f"({options.height * options.width / 1e6:.2f}MP)"
|
||||
)
|
||||
elif prompt.startswith("/g"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, guidance = prompt.split()
|
||||
options.guidance = float(guidance)
|
||||
print(f"Setting guidance to {options.guidance}")
|
||||
elif prompt.startswith("/s"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, seed = prompt.split()
|
||||
options.seed = int(seed)
|
||||
print(f"Setting seed to {options.seed}")
|
||||
elif prompt.startswith("/n"):
|
||||
if prompt.count(" ") != 1:
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
continue
|
||||
_, steps = prompt.split()
|
||||
options.num_steps = int(steps)
|
||||
print(f"Setting number of steps to {options.num_steps}")
|
||||
elif prompt.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not prompt.startswith("/h"):
|
||||
print(f"Got invalid command '{prompt}'\n{usage}")
|
||||
print(usage)
|
||||
return options
|
||||
|
||||
|
||||
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
|
||||
if options is None:
|
||||
return None
|
||||
|
||||
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
|
||||
usage = (
|
||||
"Usage: Either write your prompt directly, leave this field empty "
|
||||
"to repeat the conditioning image or write a command starting with a slash:\n"
|
||||
"- '/q' to quit"
|
||||
)
|
||||
|
||||
while True:
|
||||
img_cond_path = input(user_question)
|
||||
|
||||
if img_cond_path.startswith("/"):
|
||||
if img_cond_path.startswith("/q"):
|
||||
print("Quitting")
|
||||
return None
|
||||
else:
|
||||
if not img_cond_path.startswith("/h"):
|
||||
print(f"Got invalid command '{img_cond_path}'\n{usage}")
|
||||
print(usage)
|
||||
continue
|
||||
|
||||
if img_cond_path == "":
|
||||
break
|
||||
|
||||
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".webp")
|
||||
):
|
||||
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
|
||||
continue
|
||||
|
||||
options.img_cond_path = img_cond_path
|
||||
break
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
name: str = "flux-dev",
|
||||
width: int = 1360,
|
||||
height: int = 768,
|
||||
seed: int | None = None,
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
num_steps: int | None = None,
|
||||
loop: bool = False,
|
||||
guidance: float = 2.5,
|
||||
offload: bool = False,
|
||||
output_dir: str = "output",
|
||||
add_sampling_metadata: bool = True,
|
||||
img_cond_path: str = "assets/robot.webp",
|
||||
track_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample the flux model. Either interactively (set `--loop`) or run for a
|
||||
single image.
|
||||
|
||||
Args:
|
||||
name: Name of the base model to use (either 'flux-dev' or 'flux-schnell')
|
||||
height: height of the sample in pixels (should be a multiple of 16)
|
||||
width: width of the sample in pixels (should be a multiple of 16)
|
||||
seed: Set a seed for sampling
|
||||
device: Pytorch device
|
||||
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
||||
loop: start an interactive session and sample multiple times
|
||||
guidance: guidance value used for guidance distillation
|
||||
offload: offload models to CPU when not in use
|
||||
output_dir: where to save the output images
|
||||
add_sampling_metadata: Add the prompt to the image Exif metadata
|
||||
img_cond_path: path to conditioning image (jpeg/png/webp)
|
||||
track_usage: track usage of the model for licensing purposes
|
||||
"""
|
||||
|
||||
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
||||
|
||||
if name not in (available := ["flux-dev", "flux-schnell"]):
|
||||
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
||||
|
||||
torch_device = torch.device(device)
|
||||
if num_steps is None:
|
||||
num_steps = 4 if name == "flux-schnell" else 50
|
||||
|
||||
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
idx = 0
|
||||
else:
|
||||
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
|
||||
if len(fns) > 0:
|
||||
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
# init all components
|
||||
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
|
||||
clip = load_clip(torch_device)
|
||||
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
||||
ae = load_ae(name, device="cpu" if offload else torch_device)
|
||||
|
||||
# Download and initialize the Redux adapter
|
||||
redux_path = str(
|
||||
get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX")
|
||||
)
|
||||
img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path)
|
||||
|
||||
rng = torch.Generator(device="cpu")
|
||||
prompt = ""
|
||||
opts = SamplingOptions(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
num_steps=num_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
img_cond_path=img_cond_path,
|
||||
)
|
||||
|
||||
if loop:
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
|
||||
while opts is not None:
|
||||
if opts.seed is None:
|
||||
opts.seed = rng.seed()
|
||||
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# prepare input
|
||||
x = get_noise(
|
||||
1,
|
||||
opts.height,
|
||||
opts.width,
|
||||
device=torch_device,
|
||||
dtype=torch.bfloat16,
|
||||
seed=opts.seed,
|
||||
)
|
||||
opts.seed = None
|
||||
if offload:
|
||||
ae = ae.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
t5, clip = t5.to(torch_device), clip.to(torch_device)
|
||||
inp = prepare_redux(
|
||||
t5,
|
||||
clip,
|
||||
x,
|
||||
prompt=opts.prompt,
|
||||
encoder=img_embedder,
|
||||
img_cond_path=opts.img_cond_path,
|
||||
)
|
||||
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
||||
|
||||
# offload TEs to CPU, load model to gpu
|
||||
if offload:
|
||||
t5, clip = t5.cpu(), clip.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
||||
|
||||
# offload model, load autoencoder to gpu
|
||||
if offload:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
ae.decoder.to(x.device)
|
||||
|
||||
# decode latents to pixel space
|
||||
x = unpack(x.float(), opts.height, opts.width)
|
||||
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
||||
x = ae.decode(x)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
print(f"Done in {t1 - t0:.1f}s")
|
||||
|
||||
idx = save_image(
|
||||
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
|
||||
)
|
||||
|
||||
if loop:
|
||||
print("-" * 80)
|
||||
opts = parse_prompt(opts)
|
||||
opts = parse_img_cond_path(opts)
|
||||
else:
|
||||
opts = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
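The parse_prompt variants in this commit all accept the same "/<letter> <value>" commands; a hedged, condensed sketch of that dispatch (it deliberately skips the help/quit handling and the error messages, so it is not a drop-in replacement):

def apply_command(options, prompt: str) -> bool:
    # condensed dispatch for the "/w", "/h", "/s", "/g", "/n" commands shown above
    cmd, _, value = prompt.partition(" ")
    setters = {
        "/w": lambda v: setattr(options, "width", 16 * (int(v) // 16)),
        "/h": lambda v: setattr(options, "height", 16 * (int(v) // 16)),
        "/s": lambda v: setattr(options, "seed", int(v)),
        "/g": lambda v: setattr(options, "guidance", float(v)),
        "/n": lambda v: setattr(options, "num_steps", int(v)),
    }
    if cmd in setters and value:
        setters[cmd](value)
        return True
    return False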
171
flux/to_remove/content_filters.py
Normal file
@ -0,0 +1,171 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline
|
||||
|
||||
PROMPT_IMAGE_INTEGRITY = """
|
||||
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.
|
||||
|
||||
Output: Respond with only "yes" or "no"
|
||||
|
||||
Criteria for "yes":
|
||||
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
|
||||
- The image displays a trademarked logo or brand
|
||||
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)
|
||||
|
||||
Criteria for "no":
|
||||
- All other cases
|
||||
- When you cannot identify the specific copyrighted work or named individual
|
||||
|
||||
Critical Requirements:
|
||||
1. You must be able to name the exact copyrighted work or specific person depicted
|
||||
2. General references to demographics or characteristics are not sufficient
|
||||
3. Base your decision solely on visual content, not interpretation
|
||||
4. Provide only the one-word answer: "yes" or "no"
|
||||
""".strip()
|
||||
|
||||
|
||||
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"
|
||||
|
||||
PROMPT_TEXT_INTEGRITY = """
|
||||
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.
|
||||
|
||||
Output: Respond with only "yes" or "no"
|
||||
|
||||
Criteria for "Yes":
|
||||
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
|
||||
- The prompt explicitly mentions a trademarked logo or brand
|
||||
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)
|
||||
|
||||
Criteria for "No":
|
||||
- All other cases
|
||||
- When you cannot identify the specific copyrighted work or named individual
|
||||
|
||||
Critical Requirements:
|
||||
1. You must be able to name the exact copyrighted work or specific person referenced
|
||||
2. General demographic descriptions or characteristics are not sufficient
|
||||
3. Analyze only the prompt text, not potential image outcomes
|
||||
4. Provide only the one-word answer: "yes" or "no"
|
||||
|
||||
The prompt to check is:
|
||||
-----
|
||||
{prompt}
|
||||
-----
|
||||
|
||||
Does this prompt have copyright concerns or includes public figures?
|
||||
""".strip()
|
||||
|
||||
|
||||
class PixtralContentFilter(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
nsfw_threshold: float = 0.85,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
model_id = "mistral-community/pixtral-12b"
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device)
|
||||
|
||||
self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"])
|
||||
|
||||
self.nsfw_classifier = pipeline(
|
||||
"image-classification", model="Falconsai/nsfw_image_detection", device=device
|
||||
)
|
||||
self.nsfw_threshold = nsfw_threshold
|
||||
|
||||
def yes_no_logit_processor(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Sets all tokens but yes/no to the minimum.
|
||||
"""
|
||||
scores_yes_token = scores[:, self.yes_token].clone()
|
||||
scores_no_token = scores[:, self.no_token].clone()
|
||||
scores_min = scores.min()
|
||||
scores[:, :] = scores_min - 1
|
||||
scores[:, self.yes_token] = scores_yes_token
|
||||
scores[:, self.no_token] = scores_no_token
|
||||
return scores
|
||||
|
||||
def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
|
||||
image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
|
||||
elif isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
|
||||
if classification["score"] > self.nsfw_threshold:
|
||||
return True
|
||||
|
||||
# 512^2 pixels are enough for checking
|
||||
w, h = image.size
|
||||
f = (512**2 / (w * h)) ** 0.5
|
||||
image = image.resize((int(f * w), int(f * h)))
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"content": PROMPT_IMAGE_INTEGRITY,
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": image,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(self.model.device)
|
||||
|
||||
generate_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1,
|
||||
logits_processor=[self.yes_no_logit_processor],
|
||||
do_sample=False,
|
||||
)
|
||||
return generate_ids[0, -1].item() == self.yes_token
|
||||
|
||||
def test_txt(self, txt: str) -> bool:
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"content": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(self.model.device)
|
||||
|
||||
generate_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1,
|
||||
logits_processor=[self.yes_no_logit_processor],
|
||||
do_sample=False,
|
||||
)
|
||||
return generate_ids[0, -1].item() == self.yes_token
|
||||
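A hedged usage sketch of the filter above; test_txt and test_image are defined in this file, the import path follows the one used by cli_kontext.py earlier in this commit:

import torch
from flux.content_filters import PixtralContentFilter

content_filter = PixtralContentFilter(torch.device("cpu"))
if content_filter.test_txt("replace the logo with the text 'Black Forest Labs'"):
    print("Prompt flagged, pick another one.")
if content_filter.test_image("assets/cup.png"):
    print("Input image flagged, pick another one.")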
702
flux/util.py
Normal file
@ -0,0 +1,702 @@
|
||||
import getpass
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download, login
|
||||
from PIL import ExifTags, Image
|
||||
from safetensors.torch import load_file as load_sft
|
||||
|
||||
from flux.model import Flux, FluxLoraWrapper, FluxParams
|
||||
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from flux.modules.conditioner import HFEmbedder
|
||||
|
||||
CHECKPOINTS_DIR = Path("checkpoints")
|
||||
CHECKPOINTS_DIR.mkdir(exist_ok=True)
|
||||
BFL_API_KEY = os.getenv("BFL_API_KEY")
|
||||
|
||||
os.environ.setdefault("TRT_ENGINE_DIR", str(CHECKPOINTS_DIR / "trt_engines"))
|
||||
(CHECKPOINTS_DIR / "trt_engines").mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def ensure_hf_auth():
|
||||
hf_token = os.environ.get("HF_TOKEN")
|
||||
if hf_token:
|
||||
print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
|
||||
try:
|
||||
login(token=hf_token)
|
||||
print("Successfully authenticated with HuggingFace using HF_TOKEN")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")
|
||||
|
||||
if os.path.exists(os.path.expanduser("~/.cache/huggingface/token")):
|
||||
print("Already authenticated with HuggingFace")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def prompt_for_hf_auth():
|
||||
try:
|
||||
token = getpass.getpass("HF Token (hidden input): ").strip()
|
||||
if not token:
|
||||
print("No token provided. Aborting.")
|
||||
return False
|
||||
|
||||
login(token=token)
|
||||
print("Successfully authenticated!")
|
||||
return True
|
||||
except KeyboardInterrupt:
|
||||
print("\nAuthentication cancelled by user.")
|
||||
return False
|
||||
except Exception as auth_e:
|
||||
print(f"Authentication failed: {auth_e}")
|
||||
print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
|
||||
return False
|
||||
|
||||
|
||||
def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
|
||||
"""Get the local path for a checkpoint file, downloading if necessary."""
|
||||
# if os.environ.get(env_var) is not None:
|
||||
# local_path = os.environ[env_var]
|
||||
# if os.path.exists(local_path):
|
||||
# return Path(local_path)
|
||||
|
||||
# print(
|
||||
# f"Trying to load model {repo_id}, {filename} from environment "
|
||||
# f"variable {env_var}. But file {local_path} does not exist. "
|
||||
# "Falling back to default location."
|
||||
# )
|
||||
|
||||
# # Create a safe directory name from repo_id
|
||||
# safe_repo_name = repo_id.replace("/", "_")
|
||||
# checkpoint_dir = CHECKPOINTS_DIR / safe_repo_name
|
||||
# checkpoint_dir.mkdir(exist_ok=True)
|
||||
|
||||
# local_path = checkpoint_dir / filename
|
||||
|
||||
local_path = filename
|
||||
from mmgp import offload
|
||||
|
||||
if False:
|
||||
print(f"Downloading {filename} from {repo_id} to {local_path}")
|
||||
try:
|
||||
ensure_hf_auth()
|
||||
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
|
||||
except Exception as e:
|
||||
if "gated repo" in str(e).lower() or "restricted" in str(e).lower():
|
||||
print(f"\nError: Cannot access {repo_id} -- this is a gated repository.")
|
||||
|
||||
# Try one more time to authenticate
|
||||
if prompt_for_hf_auth():
|
||||
# Retry the download after authentication
|
||||
print("Retrying download...")
|
||||
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
|
||||
else:
|
||||
print("Authentication failed or cancelled.")
|
||||
print("You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
|
||||
raise RuntimeError(f"Authentication required for {repo_id}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
return local_path
|
||||
|
||||
|
||||
def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
|
||||
"""Download ONNX models for TRT to our checkpoints directory"""
|
||||
onnx_repo_map = {
|
||||
"flux-dev": "black-forest-labs/FLUX.1-dev-onnx",
|
||||
"flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx",
|
||||
"flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx",
|
||||
"flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx",
|
||||
"flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx",
|
||||
"flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx",
|
||||
"flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx",
|
||||
}
|
||||
|
||||
if model_name not in onnx_repo_map:
|
||||
return None # No ONNX repository required for this model
|
||||
|
||||
repo_id = onnx_repo_map[model_name]
|
||||
safe_repo_name = repo_id.replace("/", "_")
|
||||
onnx_dir = CHECKPOINTS_DIR / safe_repo_name
|
||||
|
||||
# Map of module names to their ONNX file paths (using specified precision)
|
||||
onnx_file_map = {
|
||||
"clip": "clip.opt/model.onnx",
|
||||
"transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx",
|
||||
"transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data",
|
||||
"t5": "t5.opt/model.onnx",
|
||||
"t5_data": "t5.opt/backbone.onnx_data",
|
||||
"vae": "vae.opt/model.onnx",
|
||||
}
|
||||
|
||||
# If all files exist locally, return the custom_onnx_paths format
|
||||
if onnx_dir.exists():
|
||||
all_files_exist = True
|
||||
custom_paths = []
|
||||
for module, onnx_file in onnx_file_map.items():
|
||||
if module.endswith("_data"):
|
||||
continue # Skip data files
|
||||
local_path = onnx_dir / onnx_file
|
||||
if not local_path.exists():
|
||||
all_files_exist = False
|
||||
break
|
||||
custom_paths.append(f"{module}:{local_path}")
|
||||
|
||||
if all_files_exist:
|
||||
print(f"ONNX models ready in {onnx_dir}")
|
||||
return ",".join(custom_paths)
|
||||
|
||||
# If not all files exist, download them
|
||||
print(f"Downloading ONNX models from {repo_id} to {onnx_dir}")
|
||||
print(f"Using transformer precision: {trt_transformer_precision}")
|
||||
onnx_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Download all ONNX files
|
||||
for module, onnx_file in onnx_file_map.items():
|
||||
local_path = onnx_dir / onnx_file
|
||||
if local_path.exists():
|
||||
continue # Already downloaded
|
||||
|
||||
# Create parent directories
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
print(f"Downloading {onnx_file}")
|
||||
hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir)
|
||||
except Exception as e:
|
||||
if "does not exist" in str(e).lower() or "not found" in str(e).lower():
|
||||
continue
|
||||
elif "gated repo" in str(e).lower() or "restricted" in str(e).lower():
|
||||
print(f"Cannot access {repo_id} - requires license acceptance")
|
||||
print("Please follow these steps:")
|
||||
print(f" 1. Visit: https://huggingface.co/{repo_id}")
|
||||
print(" 2. Log in to your HuggingFace account")
|
||||
print(" 3. Accept the license terms and conditions")
|
||||
print(" 4. Then retry this command")
|
||||
raise RuntimeError(f"License acceptance required for {model_name}")
|
||||
else:
|
||||
# Re-raise other errors
|
||||
raise
|
||||
|
||||
print(f"ONNX models ready in {onnx_dir}")
|
||||
|
||||
# Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2"
|
||||
# Note: Only return the actual module paths, not the data file
|
||||
custom_paths = []
|
||||
for module, onnx_file in onnx_file_map.items():
|
||||
if module.endswith("_data"):
|
||||
continue # Skip the data file in the return paths
|
||||
full_path = onnx_dir / onnx_file
|
||||
if full_path.exists():
|
||||
custom_paths.append(f"{module}:{full_path}")
|
||||
|
||||
return ",".join(custom_paths)
|
||||
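For reference, the comma-separated string returned above feeds TRTManager's custom_onnx_paths argument; for flux-dev-kontext with the default bf16 transformer it would look roughly like this (illustrative only, actual paths depend on CHECKPOINTS_DIR):

# "clip:checkpoints/black-forest-labs_FLUX.1-Kontext-dev-onnx/clip.opt/model.onnx,"
# "transformer:checkpoints/black-forest-labs_FLUX.1-Kontext-dev-onnx/transformer.opt/bf16/model.onnx,"
# "t5:checkpoints/black-forest-labs_FLUX.1-Kontext-dev-onnx/t5.opt/model.onnx,"
# "vae:checkpoints/black-forest-labs_FLUX.1-Kontext-dev-onnx/vae.opt/model.onnx"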
|
||||
|
||||
def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
|
||||
"""Check ONNX access and download models for TRT - returns ONNX directory path"""
|
||||
return download_onnx_models_for_trt(model_name, trt_transformer_precision)
|
||||
|
||||
|
||||
def track_usage_via_api(name: str, n=1) -> None:
|
||||
"""
|
||||
Track usage of licensed models via the BFL API for commercial licensing compliance.
|
||||
|
||||
For more information on licensing BFL's models for commercial use and usage reporting,
|
||||
see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true
|
||||
"""
|
||||
assert BFL_API_KEY is not None, "BFL_API_KEY is not set"
|
||||
|
||||
model_slug_map = {
|
||||
"flux-dev": "flux-1-dev",
|
||||
"flux-dev-kontext": "flux-1-kontext-dev",
|
||||
"flux-dev-fill": "flux-tools",
|
||||
"flux-dev-depth": "flux-tools",
|
||||
"flux-dev-canny": "flux-tools",
|
||||
"flux-dev-canny-lora": "flux-tools",
|
||||
"flux-dev-depth-lora": "flux-tools",
|
||||
"flux-dev-redux": "flux-tools",
|
||||
}
|
||||
|
||||
if name not in model_slug_map:
|
||||
print(f"Skipping tracking usage for {name}, as it cannot be tracked. Please check the model name.")
|
||||
return
|
||||
|
||||
model_slug = model_slug_map[name]
|
||||
url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
|
||||
headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
|
||||
payload = {"number_of_generations": n}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to track usage: {response.status_code} {response.text}")
|
||||
else:
|
||||
print(f"Successfully tracked usage for {name} with {n} generations")
|
||||
|
||||
|
||||
def save_image(
|
||||
nsfw_classifier,
|
||||
name: str,
|
||||
output_name: str,
|
||||
idx: int,
|
||||
x: torch.Tensor,
|
||||
add_sampling_metadata: bool,
|
||||
prompt: str,
|
||||
nsfw_threshold: float = 0.85,
|
||||
track_usage: bool = False,
|
||||
) -> int:
|
||||
fn = output_name.format(idx=idx)
|
||||
print(f"Saving {fn}")
|
||||
# bring into PIL format and save
|
||||
x = x.clamp(-1, 1)
|
||||
x = rearrange(x[0], "c h w -> h w c")
|
||||
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
||||
|
||||
if nsfw_classifier is not None:
|
||||
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
||||
else:
|
||||
nsfw_score = nsfw_threshold - 1.0
|
||||
|
||||
if nsfw_score < nsfw_threshold:
|
||||
exif_data = Image.Exif()
|
||||
if name in ["flux-dev", "flux-schnell"]:
|
||||
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
|
||||
else:
|
||||
exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
|
||||
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
|
||||
exif_data[ExifTags.Base.Model] = name
|
||||
if add_sampling_metadata:
|
||||
exif_data[ExifTags.Base.ImageDescription] = prompt
|
||||
img.save(fn, exif=exif_data, quality=95, subsampling=0)
|
||||
if track_usage:
|
||||
track_usage_via_api(name, 1)
|
||||
idx += 1
|
||||
else:
|
||||
print("Your generated image may contain NSFW content.")
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
repo_id: str
|
||||
repo_flow: str
|
||||
repo_ae: str
|
||||
lora_repo_id: str | None = None
|
||||
lora_filename: str | None = None
|
||||
|
||||
|
||||
configs = {
|
||||
"flux-dev": ModelSpec(
|
||||
repo_id="",
|
||||
repo_flow="",
|
||||
repo_ae="ckpts/flux_vae.safetensors",
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-schnell": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||
repo_flow="",
|
||||
repo_ae="ckpts/flux_vae.safetensors",
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Canny-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
        lora_filename="flux1-canny-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Depth-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
        lora_filename="flux1-depth-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-redux": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Redux-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fill": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Fill-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=384,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-kontext": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Kontext-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
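# A minimal sketch of how a requested size can be snapped to one of the preferred
# Kontext buckets above (the helper name is illustrative and the (width, height)
# ordering of the tuples is an assumption, not something stated here). This is why
# the output resolution may differ from the requested one.
def closest_kontext_resolution(width: int, height: int) -> tuple[int, int]:
    target = width / height
    # pick the bucket whose aspect ratio is closest to the requested one
    return min(PREFERED_KONTEXT_RESOLUTIONS, key=lambda wh: abs(target - wh[0] / wh[1]))
# e.g. closest_kontext_resolution(1280, 720) -> (1392, 752) under this assumption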


def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
    width = float(aspect_ratio.split(":")[0])
    height = float(aspect_ratio.split(":")[1])
    ratio = width / height
    width = round(math.sqrt(area * ratio))
    height = round(math.sqrt(area / ratio))
    return 16 * (width // 16), 16 * (height // 16)
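# Example usage of the helper above (values follow from the formula; `math` must
# already be imported in this module for it to run):
#
#   aspect_ratio_to_height_width("16:9")             -> (1360, 768)
#   aspect_ratio_to_height_width("1:1", area=512**2) -> (512, 512)
#
# Note that, despite its name, the function returns (width, height) in that order,
# each floored to a multiple of 16.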


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux:
    # Loading Flux
    config = configs[name]

    ckpt_path = model_filename  # config.repo_flow

    with torch.device("meta"):
        if config.lora_repo_id is not None and config.lora_filename is not None:
            model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
        else:
            model = Flux(config.params).to(torch.bfloat16)

    # print(f"Loading checkpoint: {ckpt_path}")
    from mmgp import offload
    offload.load_model_data(model, model_filename)

    # # load_sft doesn't support torch.device
    # sd = load_sft(ckpt_path, device=str(device))
    # sd = optionally_expand_state_dict(model, sd)
    # missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
    # if verbose:
    #     print_load_warning(missing, unexpected)

    # if config.lora_repo_id is not None and config.lora_filename is not None:
    #     print("Loading LoRA")
    #     lora_path = str(get_checkpoint_path(config.lora_repo_id, config.lora_filename, "FLUX_LORA"))
    #     lora_sd = load_sft(lora_path, device=str(device))
    #     # loading the lora params + overwriting scale values in the norms
    #     missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
    #     if verbose:
    #         print_load_warning(missing, unexpected)
    return model
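# Hypothetical usage sketch; the checkpoint path below is a placeholder, not a file
# referenced by this code. The model is built on the "meta" device and the weights
# are then attached by mmgp's offload.load_model_data.
#
#   model = load_flow_model(
#       name="flux-dev-kontext",
#       model_filename="path/to/flux1_kontext_dev_bf16.safetensors",  # placeholder
#       device="cuda",
#   )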


def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("", text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder("ckpts/clip_vit_large_patch14", "", max_length=77, torch_dtype=torch.bfloat16, is_clip=True).to(device)


def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    config = configs[name]
    ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE"))

    # Loading the autoencoder
    with torch.device("meta"):
        ae = AutoEncoder(config.ae_params)

    # print(f"Loading AE checkpoint: {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae
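# Sketch of how the loaders above could be combined into the usual trio of modules
# (the text encoder path is a placeholder assumption):
#
#   t5 = load_t5(device="cuda", text_encoder_filename="path/to/t5_encoder.safetensors", max_length=512)
#   clip = load_clip(device="cuda")
#   ae = load_ae("flux-dev-kontext", device="cuda")
#
# load_ae resolves the VAE checkpoint via get_checkpoint_path(config.repo_id,
# config.repo_ae, "FLUX_AE") and reports any missing/unexpected keys through
# print_load_warning.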


def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """
    Optionally expand the state dict to match the model's parameters shapes.
    """
    for name, param in model.named_parameters():
        if name in state_dict:
            if state_dict[name].shape != param.shape:
                print(
                    f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
                )
                # expand with zeros:
                expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
                slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
                expanded_state_dict_weight[slices] = state_dict[name]
                state_dict[name] = expanded_state_dict_weight

    return state_dict
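# Minimal illustration of the zero-expansion above, with hypothetical shapes and a
# hypothetical key name: a (64, 64) checkpoint tensor loaded into a (128, 64)
# parameter keeps its values in rows 0..63 and leaves the remaining rows at zero.
#
#   param = torch.nn.Parameter(torch.empty(128, 64))
#   sd = {"proj.weight": torch.randn(64, 64)}
#   expanded = torch.zeros_like(param, device=sd["proj.weight"].device)
#   slices = tuple(slice(0, dim) for dim in sd["proj.weight"].shape)
#   expanded[slices] = sd["proj.weight"]
#   sd["proj.weight"] = expanded  # now shaped (128, 64)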


@ -149,6 +149,7 @@ class LTXV:
        self,
        model_filepath: str,
        text_encoder_filepath: str,
        model_def,
        dtype = torch.bfloat16,
        VAE_dtype = torch.bfloat16,
        mixed_precision_transformer = False
@ -157,8 +158,8 @@ class LTXV:
        # if dtype == torch.float16:
        dtype = torch.bfloat16
        self.mixed_precision_transformer = mixed_precision_transformer
        self.distilled = any("lora" in name for name in model_filepath)
        model_filepath = [name for name in model_filepath if not "lora" in name]
        self.model_def = model_def
        self.pipeline_config = model_def["LTXV_config"]
        # with safe_open(ckpt_path, framework="pt") as f:
        #     metadata = f.metadata()
        #     config_str = metadata.get("config")
@ -220,11 +221,11 @@ class LTXV:
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

        if prompt_enhancer_image_caption_model != None:
            pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model
            prompt_enhancer_image_caption_model._model_dtype = torch.float
        # if prompt_enhancer_image_caption_model != None:
        #     pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model
        #     prompt_enhancer_image_caption_model._model_dtype = torch.float

        pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model
        # pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model

        # offload.profile(pipe, profile_no=5, extraModelsToQuantize = None, quantizeTransformer = False, budgets = { "prompt_enhancer_llm_model" : 10000, "prompt_enhancer_image_caption_model" : 10000, "vae" : 3000, "*" : 100 }, verboseLevel=2)

@ -299,14 +300,10 @@ class LTXV:
        conditioning_media_paths = None
        conditioning_start_frames = None

        if self.distilled:
            pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"
        else:
            pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml"
        # check if pipeline_config is a file
        if not os.path.isfile(pipeline_config):
            raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
        with open(pipeline_config, "r") as f:
        if not os.path.isfile(self.pipeline_config):
            raise ValueError(f"Pipeline config file {self.pipeline_config} does not exist")
        with open(self.pipeline_config, "r") as f:
            pipeline_config = yaml.safe_load(f)

@ -520,7 +517,7 @@ def get_media_num_frames(media_path: str) -> int:
        return media_path.shape[1]
    elif isinstance(media_path, str) and any( media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]):
        reader = imageio.get_reader(media_path)
        return min(reader.count_frames(), max_frames)
        return min(reader.count_frames(), 0) # to do
    else:
        raise Exception("video format not supported")

@ -564,6 +561,3 @@ def load_media_file(
        raise Exception("video format not supported")
    return media_tensor


if __name__ == "__main__":
    main()
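# Sketch of the per-model config flow introduced above (the model_def dict here is
# illustrative; the YAML path is one of the files named in the code above): the
# pipeline config now comes from model_def["LTXV_config"] instead of a hard-coded
# distilled/dev switch.
#
#   model_def = {"LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml"}
#   config_path = model_def["LTXV_config"]
#   if not os.path.isfile(config_path):
#       raise ValueError(f"Pipeline config file {config_path} does not exist")
#   with open(config_path, "r") as f:
#       pipeline_config = yaml.safe_load(f)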
@ -106,18 +106,18 @@ class WanAny2V:
|
||||
# config = json.load(f)
|
||||
# from mmgp import safetensors2
|
||||
# sd = safetensors2.torch_load_file(xmodel_filename)
|
||||
|
||||
# model_filename = "c:/temp/fflf/diffusion_pytorch_model-00001-of-00007.safetensors"
|
||||
base_config_file = f"configs/{base_model_type}.json"
|
||||
forcedConfigPath = base_config_file if len(model_filename) > 1 or base_model_type in ["flf2v_720p"] else None
|
||||
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
|
||||
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
|
||||
# model_filename[1] = xmodel_filename
|
||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
|
||||
# self.model = offload.load_model_data(self.model, xmodel_filename )
|
||||
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
||||
# self.model.to(torch.bfloat16)
|
||||
# self.model.cpu()
|
||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
||||
offload.change_dtype(self.model, dtype, True)
|
||||
# offload.save_model(self.model, "multitalkbf16.safetensors", config_file_path=base_config_file, filter_sd=sd)
|
||||
# offload.save_model(self.model, "flf2v_720p.safetensors", config_file_path=base_config_file)
|
||||
# offload.save_model(self.model, "flf2v_quanto_int8_fp16_720p.safetensors", do_quantize= True, config_file_path=base_config_file)
|
||||
# offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd)
|
||||
|
||||
# offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file)
|
||||
@ -126,7 +126,7 @@ class WanAny2V:
|
||||
self.model.eval().requires_grad_(False)
|
||||
if save_quantized:
|
||||
from wgp import save_quantized_model
|
||||
save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
|
||||
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
@ -477,8 +477,8 @@ class WanAny2V:
|
||||
any_end_frame = False
|
||||
if input_frames != None:
|
||||
_ , preframes_count, height, width = input_frames.shape
|
||||
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||
clip_context = self.clip.visual([input_frames[:, -1:]]) #.to(self.param_dtype)
|
||||
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||
clip_context = self.clip.visual([input_frames[:, -1:]]) if model_type != "flf2v_720p" else self.clip.visual([input_frames[:, -1:], input_frames[:, -1:]])
|
||||
input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype)
|
||||
enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width),
|
||||
device=self.device, dtype= self.VAE_dtype)],
|
||||
@ -488,7 +488,7 @@ class WanAny2V:
|
||||
preframes_count = 1
|
||||
image_start = TF.to_tensor(image_start)
|
||||
any_end_frame = image_end != None
|
||||
add_frames_for_end_image = any_end_frame and model_type not in ["fun_inp_1.3B", "fun_inp", "i2v_720p"]
|
||||
add_frames_for_end_image = any_end_frame and model_type == "i2v"
|
||||
if any_end_frame:
|
||||
image_end = TF.to_tensor(image_end)
|
||||
if add_frames_for_end_image:
|
||||
@ -517,8 +517,8 @@ class WanAny2V:
|
||||
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
||||
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
|
||||
image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
||||
if image_end != None and model_type == "flf2v_720p":
|
||||
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :]])
|
||||
if model_type == "flf2v_720p":
|
||||
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end != None else image_start[:, None, :, :]])
|
||||
else:
|
||||
clip_context = self.clip.visual([image_start[:, None, :, :]])
|
||||
|
||||
@ -753,7 +753,7 @@ class WanAny2V:
|
||||
"context" : [context, context_null, context_null],
|
||||
"audio_scale": [audio_scale, None, None ]
|
||||
}
|
||||
elif multitalk:
|
||||
elif multitalk and audio_proj != None:
|
||||
gen_args = {
|
||||
"x" : [latent_model_input, latent_model_input, latent_model_input],
|
||||
"context" : [context, context_null, context_null],
|
||||
|
||||
@ -184,6 +184,7 @@ def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combinat
|
||||
|
||||
|
||||
def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5):
|
||||
if full_audio_embs == None: return None
|
||||
HUMAN_NUMBER = len(full_audio_embs)
|
||||
audio_end_idx = audio_start_idx + clip_length
|
||||
indices = (torch.arange(2 * 2 + 1) - 2) * 1
|
||||
|
||||