From eb92f0c11cdb8c9015a6bbe10509deedd76a5dd7 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 13 Jul 2025 04:24:55 +0200 Subject: [PATCH 1/9] flux kontext --- defaults/ReadMe.txt | 13 + {finetunes => defaults}/fantasy.json | 0 {finetunes => defaults}/flf2v_720p.json | 10 +- defaults/flux_dev_kontext.json | 19 + defaults/fun_inp.json | 13 + defaults/fun_inp_1.3B.json | 11 + defaults/hunyuan.json | 12 + defaults/hunyuan_avatar.json | 12 + defaults/hunyuan_custom.json | 12 + defaults/hunyuan_custom_audio.json | 12 + defaults/hunyuan_custom_edit.json | 12 + defaults/hunyuan_i2v.json | 12 + .../hunyuan_t2v_accvideo.json | 0 {finetunes => defaults}/hunyuan_t2v_fast.json | 0 defaults/i2v.json | 13 + defaults/i2v_720p.json | 14 + defaults/i2v_fusionix.json | 10 + defaults/ltxv_13B.json | 14 + defaults/ltxv_distilled.json | 14 + {finetunes => defaults}/moviigen.json | 0 {finetunes => defaults}/multitalk.json | 0 {finetunes => defaults}/multitalk_720p.json | 0 defaults/phantom_1.3B.json | 11 + defaults/phantom_14B.json | 13 + defaults/recam_1.3B.json | 11 + defaults/sky_df_1.3B.json | 11 + defaults/sky_df_14B.json | 13 + defaults/sky_df_720p_14B.json | 14 + defaults/t2v.json | 13 + defaults/t2v_1.3B.json | 11 + {finetunes => defaults}/t2v_fusionix.json | 0 {finetunes => defaults}/t2v_sf.json | 0 defaults/vace_1.3B.json | 11 + {finetunes => defaults}/vace_14B.json | 0 .../vace_14B_fusionix.json | 0 {finetunes => defaults}/vace_14B_sf.json | 0 .../vace_multitalk_14B.json | 0 flux/__init__.py | 13 + flux/__main__.py | 18 + flux/_version.py | 21 + flux/flux_main.py | 112 +++ flux/math.py | 54 ++ flux/model.py | 168 +++++ flux/modules/autoencoder.py | 320 ++++++++ flux/modules/conditioner.py | 38 + flux/modules/image_embedders.py | 99 +++ flux/modules/layers copy.py | 327 ++++++++ flux/modules/layers.py | 328 ++++++++ flux/modules/lora.py | 94 +++ flux/sampling.py | 392 ++++++++++ flux/to_remove/cli.py | 302 ++++++++ flux/to_remove/cli_control.py | 390 ++++++++++ flux/to_remove/cli_fill.py | 334 +++++++++ flux/to_remove/cli_kontext.py | 368 +++++++++ flux/to_remove/cli_redux.py | 290 ++++++++ flux/to_remove/content_filters.py | 171 +++++ flux/util.py | 702 ++++++++++++++++++ ltx_video/ltxv.py | 28 +- wan/any2video.py | 24 +- wan/multitalk/multitalk.py | 1 + wgp.py | 660 ++++++++-------- 61 files changed, 5226 insertions(+), 339 deletions(-) create mode 100644 defaults/ReadMe.txt rename {finetunes => defaults}/fantasy.json (100%) rename {finetunes => defaults}/flf2v_720p.json (67%) create mode 100644 defaults/flux_dev_kontext.json create mode 100644 defaults/fun_inp.json create mode 100644 defaults/fun_inp_1.3B.json create mode 100644 defaults/hunyuan.json create mode 100644 defaults/hunyuan_avatar.json create mode 100644 defaults/hunyuan_custom.json create mode 100644 defaults/hunyuan_custom_audio.json create mode 100644 defaults/hunyuan_custom_edit.json create mode 100644 defaults/hunyuan_i2v.json rename {finetunes => defaults}/hunyuan_t2v_accvideo.json (100%) rename {finetunes => defaults}/hunyuan_t2v_fast.json (100%) create mode 100644 defaults/i2v.json create mode 100644 defaults/i2v_720p.json create mode 100644 defaults/i2v_fusionix.json create mode 100644 defaults/ltxv_13B.json create mode 100644 defaults/ltxv_distilled.json rename {finetunes => defaults}/moviigen.json (100%) rename {finetunes => defaults}/multitalk.json (100%) rename {finetunes => defaults}/multitalk_720p.json (100%) create mode 100644 defaults/phantom_1.3B.json create mode 100644 defaults/phantom_14B.json create mode 
100644 defaults/recam_1.3B.json create mode 100644 defaults/sky_df_1.3B.json create mode 100644 defaults/sky_df_14B.json create mode 100644 defaults/sky_df_720p_14B.json create mode 100644 defaults/t2v.json create mode 100644 defaults/t2v_1.3B.json rename {finetunes => defaults}/t2v_fusionix.json (100%) rename {finetunes => defaults}/t2v_sf.json (100%) create mode 100644 defaults/vace_1.3B.json rename {finetunes => defaults}/vace_14B.json (100%) rename {finetunes => defaults}/vace_14B_fusionix.json (100%) rename {finetunes => defaults}/vace_14B_sf.json (100%) rename {finetunes => defaults}/vace_multitalk_14B.json (100%) create mode 100644 flux/__init__.py create mode 100644 flux/__main__.py create mode 100644 flux/_version.py create mode 100644 flux/flux_main.py create mode 100644 flux/math.py create mode 100644 flux/model.py create mode 100644 flux/modules/autoencoder.py create mode 100644 flux/modules/conditioner.py create mode 100644 flux/modules/image_embedders.py create mode 100644 flux/modules/layers copy.py create mode 100644 flux/modules/layers.py create mode 100644 flux/modules/lora.py create mode 100644 flux/sampling.py create mode 100644 flux/to_remove/cli.py create mode 100644 flux/to_remove/cli_control.py create mode 100644 flux/to_remove/cli_fill.py create mode 100644 flux/to_remove/cli_kontext.py create mode 100644 flux/to_remove/cli_redux.py create mode 100644 flux/to_remove/content_filters.py create mode 100644 flux/util.py diff --git a/defaults/ReadMe.txt b/defaults/ReadMe.txt new file mode 100644 index 0000000..c98ee2e --- /dev/null +++ b/defaults/ReadMe.txt @@ -0,0 +1,13 @@ +Please do not modify any file in this folder. + +If you want to change a property of a default model, copy the corresponding model file into the ./finetunes folder and modify the properties you want to change in the new file. +If a property is not in the new file, it will be inherited automatically from the default file with the same name.
+ +For instance to hide a model: + +{ + "model": + { + "visible": false + } +} diff --git a/finetunes/fantasy.json b/defaults/fantasy.json similarity index 100% rename from finetunes/fantasy.json rename to defaults/fantasy.json diff --git a/finetunes/flf2v_720p.json b/defaults/flf2v_720p.json similarity index 67% rename from finetunes/flf2v_720p.json rename to defaults/flf2v_720p.json index 88b5387..b25c438 100644 --- a/finetunes/flf2v_720p.json +++ b/defaults/flf2v_720p.json @@ -1,14 +1,14 @@ { "model": { - "name": "First Last Frame to Video 720p (FLF2V)14B", + "name": "First Last Frame to Video 720p (FLF2V) 14B", "architecture" : "flf2v_720p", - "visible" : false, + "visible" : true, "description": "The First Last Frame 2 Video model is the official model Image 2 Video model that supports Start and End frames.", "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_bf16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_mfp16_int8.safetensors" ], "auto_quantize": true }, diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json new file mode 100644 index 0000000..d8efcd9 --- /dev/null +++ b/defaults/flux_dev_kontext.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Flux Dev Kontext 12B", + "architecture": "flux_dev_kontext", + "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. 
Please be aware that the output resolution is modified by Flux Kontext and may not be what you requested.", + "URLs": [ + "c:/temp/kontext/flux1_kontext_dev_bf16.safetensors", + "c:/temp/kontext/flux1_kontext_dev_quanto_bf16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" + ] + }, + "resolution": "1280x720", + "video_length": "1" +} + + \ No newline at end of file diff --git a/defaults/fun_inp.json b/defaults/fun_inp.json new file mode 100644 index 0000000..65330cd --- /dev/null +++ b/defaults/fun_inp.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "Fun InP image2video 14B", + "architecture" : "fun_inp", + "description": "The Fun model is an alternative image 2 video model that supports out of the box End Image fixing (contrary to the original Wan image 2 video model).", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors" + ] + } +} diff --git a/defaults/fun_inp_1.3B.json b/defaults/fun_inp_1.3B.json new file mode 100644 index 0000000..9d60e63 --- /dev/null +++ b/defaults/fun_inp_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "Fun InP image2video 1.3B", + "architecture" : "fun_inp_1.3B", + "description": "The Fun model is an alternative image 2 video model that supports out of the box End Image fixing (contrary to the original Wan image 2 video model). This version also adds image 2 video capability to the 1.3B model.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Fun_InP_1.3B_bf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan.json b/defaults/hunyuan.json new file mode 100644 index 0000000..5012c02 --- /dev/null +++ b/defaults/hunyuan.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video text2video 720p 13B", + "architecture" : "hunyuan", + "description": "Probably the best text 2 video model available.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_quanto_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan_avatar.json b/defaults/hunyuan_avatar.json new file mode 100644 index 0000000..d01c318 --- /dev/null +++ b/defaults/hunyuan_avatar.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video Avatar 720p 13B", + "architecture" : "hunyuan_avatar", + "description": "With the Hunyuan Video Avatar model you can animate a person based on the content of an audio input. Please note that the video generator works by processing 128-frame segments at a time (even if you ask for less).
The good news is that it will concatenate multiple segments for long video generation (max 3 segments recommended, as the quality will get worse).", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan_custom.json b/defaults/hunyuan_custom.json new file mode 100644 index 0000000..d6217e9 --- /dev/null +++ b/defaults/hunyuan_custom.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video Custom 720p 13B", + "architecture" : "hunyuan_custom", + "description": "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the moment) as it is quite good at keeping their identity. However it is slow, since to get good results you need to generate 720p videos with 30 steps.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_720_quanto_bf16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan_custom_audio.json b/defaults/hunyuan_custom_audio.json new file mode 100644 index 0000000..f5c4d52 --- /dev/null +++ b/defaults/hunyuan_custom_audio.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video Custom Audio 720p 13B", + "architecture" : "hunyuan_custom_audio", + "description": "The Hunyuan Video Custom Audio model can be used to generate scenes of a person speaking given a Reference Image and a Recorded Voice or Song. The reference image is not a start image, therefore the person can be represented in a different context. The video length can be anything up to 10s. It is also quite good at generating silent Videos of a person.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan_custom_edit.json b/defaults/hunyuan_custom_edit.json new file mode 100644 index 0000000..9cf037e --- /dev/null +++ b/defaults/hunyuan_custom_edit.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video Custom Edit 720p 13B", + "architecture" : "hunyuan_custom_edit", + "description": "The Hunyuan Video Custom Edit model can be used to do Video inpainting on a person (add accessories or completely replace the person).
In any case you will need to define a Video Mask that indicates which area of the Video should be edited.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_custom_edit_720_quanto_bf16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/hunyuan_i2v.json b/defaults/hunyuan_i2v.json new file mode 100644 index 0000000..400a6a3 --- /dev/null +++ b/defaults/hunyuan_i2v.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Hunyuan Video image2video 720p 13B", + "architecture" : "hunyuan_i2v", + "description": "A good looking image 2 video model, but not so good at prompt adherence.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_bf16v2.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_i2v_720_quanto_int8v2.safetensors" + ] + } +} \ No newline at end of file diff --git a/finetunes/hunyuan_t2v_accvideo.json b/defaults/hunyuan_t2v_accvideo.json similarity index 100% rename from finetunes/hunyuan_t2v_accvideo.json rename to defaults/hunyuan_t2v_accvideo.json diff --git a/finetunes/hunyuan_t2v_fast.json b/defaults/hunyuan_t2v_fast.json similarity index 100% rename from finetunes/hunyuan_t2v_fast.json rename to defaults/hunyuan_t2v_fast.json diff --git a/defaults/i2v.json b/defaults/i2v.json new file mode 100644 index 0000000..33a4d55 --- /dev/null +++ b/defaults/i2v.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "Wan2.1 image2video 480p 14B", + "architecture" : "i2v", + "description": "The standard Wan Image 2 Video model specialized to generate 480p videos. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/i2v_720p.json b/defaults/i2v_720p.json new file mode 100644 index 0000000..90523de --- /dev/null +++ b/defaults/i2v_720p.json @@ -0,0 +1,14 @@ +{ + "model": + { + "name": "Wan2.1 image2video 720p 14B", + "architecture" : "i2v", + "description": "The standard Wan Image 2 Video model specialized to generate 720p videos.
It also offers Start and End Image support (End Image is not supported in the original model but seems to work well).", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors" + ] + }, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/i2v_fusionix.json b/defaults/i2v_fusionix.json new file mode 100644 index 0000000..ffbb0a1 --- /dev/null +++ b/defaults/i2v_fusionix.json @@ -0,0 +1,10 @@ +{ + "model": + { + "name": "Wan2.1 image2video 480p FusioniX 14B", + "architecture" : "i2v", + "description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", + "URLs": "i2v", + "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"] + } +} \ No newline at end of file diff --git a/defaults/ltxv_13B.json b/defaults/ltxv_13B.json new file mode 100644 index 0000000..7e45e9a --- /dev/null +++ b/defaults/ltxv_13B.json @@ -0,0 +1,14 @@ +{ + "model": + { + "name": "LTX Video 0.9.7 13B", + "architecture" : "ltxv_13B", + "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). It is recommended to keep the number of steps to 30 or you will need to update the file 'ltx_video/configs/ltxv-13b-0.9.7-dev.yaml'. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors" + ], + "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml" + }, + "num_inference_steps": 30 +} diff --git a/defaults/ltxv_distilled.json b/defaults/ltxv_distilled.json new file mode 100644 index 0000000..256ea81 --- /dev/null +++ b/defaults/ltxv_distilled.json @@ -0,0 +1,14 @@ +{ + "model": + { + "name": "LTX Video 0.9.7 Distilled 13B", + "architecture" : "ltxv_13B", + "description": "LTX Video is a fast model that can be used to generate long videos (up to 260 frames). This distilled version is very fast and retains a high level of quality.
The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer.", + "URLs": "ltxv_13B", + "loras": ["https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"], + "loras_multipliers": [ 1 ], + "lock_inference_steps": true, + "LTXV_config": "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml" + }, + "num_inference_steps": 6 +} diff --git a/finetunes/moviigen.json b/defaults/moviigen.json similarity index 100% rename from finetunes/moviigen.json rename to defaults/moviigen.json diff --git a/finetunes/multitalk.json b/defaults/multitalk.json similarity index 100% rename from finetunes/multitalk.json rename to defaults/multitalk.json diff --git a/finetunes/multitalk_720p.json b/defaults/multitalk_720p.json similarity index 100% rename from finetunes/multitalk_720p.json rename to defaults/multitalk_720p.json diff --git a/defaults/phantom_1.3B.json b/defaults/phantom_1.3B.json new file mode 100644 index 0000000..5be31da --- /dev/null +++ b/defaults/phantom_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "Phantom 1.3B", + "architecture" : "phantom_1.3B", + "description": "The Phantom model is specialized in transferring people or objects of your choice into a generated Video. It produces very nice results when used at 720p.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2_1_phantom_1.3B_mbf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/phantom_14B.json b/defaults/phantom_14B.json new file mode 100644 index 0000000..e6ec614 --- /dev/null +++ b/defaults/phantom_14B.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "Phantom 14B", + "architecture" : "phantom_14B", + "description": "The Phantom model is specialized in transferring people or objects of your choice into a generated Video. It produces very nice results when used at 720p.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/recam_1.3B.json b/defaults/recam_1.3B.json new file mode 100644 index 0000000..e65d1b2 --- /dev/null +++ b/defaults/recam_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "ReCamMaster 1.3B", + "architecture" : "recam_1.3B", + "description": "The Recam Master in theory should allow you to replay a video by applying a different camera movement. The model supports only videos that are at least 81 frames long (any frames beyond that will be ignored).", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_recammaster_1.3B_bf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/sky_df_1.3B.json b/defaults/sky_df_1.3B.json new file mode 100644 index 0000000..61e118d --- /dev/null +++ b/defaults/sky_df_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "SkyReels2 Diffusion Forcing 1.3B", + "architecture" : "sky_df_1.3B", + "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit.
You can also use this model to extend any existing video.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/sky_df_14B.json b/defaults/sky_df_14B.json new file mode 100644 index 0000000..e9d7bd5 --- /dev/null +++ b/defaults/sky_df_14B.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "SkyReels2 Diffusion Forcing 540p 14B", + "architecture" : "sky_df_14B", + "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/sky_df_720p_14B.json b/defaults/sky_df_720p_14B.json new file mode 100644 index 0000000..6bae666 --- /dev/null +++ b/defaults/sky_df_720p_14B.json @@ -0,0 +1,14 @@ +{ + "model": + { + "name": "SkyReels2 Diffusion Forcing 720p 14B", + "architecture" : "sky_df_14B", + "description": "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors" + ] + }, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/t2v.json b/defaults/t2v.json new file mode 100644 index 0000000..2ab946a --- /dev/null +++ b/defaults/t2v.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "Wan2.1 text2video 14B", + "architecture" : "t2v", + "description": "The original Wan Text 2 Video model. Most other models have been built on top of it.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_14B_quanto_mfp16_int8.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/t2v_1.3B.json b/defaults/t2v_1.3B.json new file mode 100644 index 0000000..859304f --- /dev/null +++ b/defaults/t2v_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "Wan2.1 text2video 1.3B", + "architecture" : "t2v_1.3B", + "description": "The light version of the original Wan Text 2 Video model.
Most other models have been built on top of it.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_text2video_1.3B_bf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/finetunes/t2v_fusionix.json b/defaults/t2v_fusionix.json similarity index 100% rename from finetunes/t2v_fusionix.json rename to defaults/t2v_fusionix.json diff --git a/finetunes/t2v_sf.json b/defaults/t2v_sf.json similarity index 100% rename from finetunes/t2v_sf.json rename to defaults/t2v_sf.json diff --git a/defaults/vace_1.3B.json b/defaults/vace_1.3B.json new file mode 100644 index 0000000..716fbd0 --- /dev/null +++ b/defaults/vace_1.3B.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "Vace ControlNet 1.3B", + "architecture" : "vace_1.3B", + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based on additional custom data: pose or depth videos, images or objects you want to see in the video.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1.3B_mbf16.safetensors" + ] + } +} \ No newline at end of file diff --git a/finetunes/vace_14B.json b/defaults/vace_14B.json similarity index 100% rename from finetunes/vace_14B.json rename to defaults/vace_14B.json diff --git a/finetunes/vace_14B_fusionix.json b/defaults/vace_14B_fusionix.json similarity index 100% rename from finetunes/vace_14B_fusionix.json rename to defaults/vace_14B_fusionix.json diff --git a/finetunes/vace_14B_sf.json b/defaults/vace_14B_sf.json similarity index 100% rename from finetunes/vace_14B_sf.json rename to defaults/vace_14B_sf.json diff --git a/finetunes/vace_multitalk_14B.json b/defaults/vace_multitalk_14B.json similarity index 100% rename from finetunes/vace_multitalk_14B.json rename to defaults/vace_multitalk_14B.json diff --git a/flux/__init__.py b/flux/__init__.py new file mode 100644 index 0000000..dddc6a3 --- /dev/null +++ b/flux/__init__.py @@ -0,0 +1,13 @@ +try: + from ._version import ( + version as __version__, # type: ignore + version_tuple, + ) +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/flux/__main__.py b/flux/__main__.py new file mode 100644 index 0000000..d365c0f --- /dev/null +++ b/flux/__main__.py @@ -0,0 +1,18 @@ +from fire import Fire + +from .cli import main as cli_main +from .cli_control import main as control_main +from .cli_fill import main as fill_main +from .cli_kontext import main as kontext_main +from .cli_redux import main as redux_main + +if __name__ == "__main__": + Fire( + { + "t2i": cli_main, + "control": control_main, + "fill": fill_main, + "kontext": kontext_main, + "redux": redux_main, + } + ) diff --git a/flux/_version.py b/flux/_version.py new file mode 100644 index 0000000..fdf5bff --- /dev/null +++ b/flux/_version.py @@ -0,0 +1,21 @@ +# file generated by setuptools-scm +# don't change, don't track in version control + +__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"] + +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple + from typing import Union + + VERSION_TUPLE = Tuple[Union[int, str], ...]
+else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '0.0.post58+g1371b2b' +__version_tuple__ = version_tuple = (0, 0, 'post58', 'g1371b2b') diff --git a/flux/flux_main.py b/flux/flux_main.py new file mode 100644 index 0000000..f4b1994 --- /dev/null +++ b/flux/flux_main.py @@ -0,0 +1,112 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob +from mmgp import offload as offload +import torch +from wan.utils.utils import calculate_new_dimensions +from flux.sampling import denoise, get_schedule, prepare_kontext, unpack +from flux.modules.layers import get_linear_split_map +from flux.util import ( + aspect_ratio_to_height_width, + load_ae, + load_clip, + load_flow_model, + load_t5, + save_image, +) + +class model_factory: + def __init__( + self, + checkpoint_dir, + model_filename = None, + model_type = None, + base_model_type = None, + text_encoder_filename = None, + quantizeTransformer = False, + save_quantized = False, + dtype = torch.bfloat16, + VAE_dtype = torch.float32, + mixed_precision_transformer = False + ): + self.device = torch.device(f"cuda") + self.VAE_dtype = VAE_dtype + self.dtype = dtype + torch_device = "cpu" + + self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512) + self.clip = load_clip(torch_device) + self.name= "flux-dev-kontext" + self.model = load_flow_model(self.name, model_filename[0], torch_device) + + self.vae = load_ae(self.name, device=torch_device) + + # offload.change_dtype(self.model, dtype, True) + if save_quantized: + from wgp import save_quantized_model + save_quantized_model(self.model, model_type, model_filename[0], dtype, None) + + split_linear_modules_map = get_linear_split_map() + self.model.split_linear_modules_map = split_linear_modules_map + offload.split_linear_modules(self.model, split_linear_modules_map ) + + + def generate( + self, + seed: int | None = None, + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + sampling_steps: int = 20, + input_ref_images = None, + width= 832, + height=480, + guide_scale: float = 2.5, + fit_into_canvas = None, + callback = None, + loras_slists = None, + frame_num = 1, + **bbargs + ): + + if self._interrupt: + return None + + rng = torch.Generator(device="cuda") + if seed is None: + seed = rng.seed() + + if input_ref_images != None and len(input_ref_images) > 0: + image_ref = input_ref_images[0] + w, h = image_ref.size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + + inp, height, width = prepare_kontext( + t5=self.t5, + clip=self.clip, + prompt=input_prompt, + ae=self.vae, + img_cond=image_ref, + target_width=width, + target_height=height, + bs=frame_num, + seed=seed, + device="cuda", + ) + + inp.pop("img_cond_orig") + timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) + def unpack_latent(x): + return unpack(x.float(), height, width) + # denoise initial noise + x = denoise(self.model, **inp, timesteps=timesteps, guidance=guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent) + if x==None: return None + # decode latents to pixel space + x = unpack_latent(x) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + x = self.vae.decode(x) + + x = x.clamp(-1, 1) + x = x.transpose(0, 1) + return x + diff --git a/flux/math.py b/flux/math.py new file mode 100644 index 0000000..9e8aa59 --- 
/dev/null +++ b/flux/math.py @@ -0,0 +1,54 @@ +import torch +from einops import rearrange +from torch import Tensor +from wan.modules.attention import pay_attention + + +def attention(qkv_list, pe: Tensor) -> Tensor: + q, k, v = qkv_list + qkv_list.clear() + q_list = [q] + q = None + q = apply_rope_(q_list, pe) + k_list = [k] + k = None + k = apply_rope_(k_list, pe) + qkv_list = [q.transpose(1,2), k.transpose(1,2) ,v.transpose(1,2)] + del q,k, v + x = pay_attention(qkv_list).transpose(1,2) + # x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope_(q_list, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq= q_list[0] + xqshape = xq.shape + xqdtype= xq.dtype + q_list.clear() + xq = xq.float().reshape(*xqshape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq[..., 0] + xq = freqs_cis[..., 1] * xq[..., 1] + + xq_out.add_(xq) + # xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + + return xq_out.reshape(*xqshape).to(xqdtype) + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/flux/model.py b/flux/model.py new file mode 100644 index 0000000..1802ae6 --- /dev/null +++ b/flux/model.py @@ -0,0 +1,168 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from flux.modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) +from flux.modules.lora import LinearLora, replace_linear_with_lora + + +@dataclass +class FluxParams: + in_channels: int + out_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. 
+ """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def preprocess_loras(self, model_type, sd): + new_sd = {} + if len(sd) == 0: return sd + + first_key= next(iter(sd)) + if first_key.startswith("transformer."): + src_list = [".attn.to_q.", ".attn.to_k.", ".attn.to_v."] + tgt_list = [".linear1_attn_q.", ".linear1_attn_k.", ".linear1_attn_v."] + for k,v in sd.items(): + k = k.replace("transformer.single_transformer_blocks", "diffusion_model.single_blocks") + k = k.replace("transformer.double_transformer_blocks", "diffusion_model.double_blocks") + for src, tgt in zip(src_list, tgt_list): + k = k.replace(src, tgt) + + new_sd[k] = v + + return new_sd + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + callback= None, + pipeline =None, + + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec += self.guidance_in(timestep_embedding(guidance, 256)) + vec += self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + if callback != None: + callback(-1, None, False, True) + if pipeline._interrupt: + return None + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] 
+ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + +class FluxLoraWrapper(Flux): + def __init__( + self, + lora_rank: int = 128, + lora_scale: float = 1.0, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.lora_rank = lora_rank + + replace_linear_with_lora( + self, + max_rank=lora_rank, + scale=lora_scale, + ) + + def set_lora_scale(self, scale: float) -> None: + for module in self.modules(): + if isinstance(module, LinearLora): + module.set_scale(scale=scale) diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py new file mode 100644 index 0000000..f31b731 --- /dev/null +++ b/flux/modules/autoencoder.py @@ -0,0 +1,320 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, 
in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + 
for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams, sample_z: bool = False): + super().__init__() + self.params = params + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian(sample=sample_z) + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def get_VAE_tile_size(*args, **kwargs): + return [] + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/flux/modules/conditioner.py b/flux/modules/conditioner.py new file mode 100644 index 0000000..29e3b67 --- /dev/null +++ b/flux/modules/conditioner.py @@ -0,0 +1,38 @@ +from torch import Tensor, nn +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer +import os + +class HFEmbedder(nn.Module): + def __init__(self, version: str, text_encoder_filename, max_length: int, is_clip = False, **hf_kwargs): + super().__init__() + self.is_clip = is_clip + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + from mmgp import offload as offloadobj 
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(os.path.dirname(text_encoder_filename), max_length=max_length) + self.hf_module: T5EncoderModel = offloadobj.fast_load_transformers_model(text_encoder_filename) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key].bfloat16() diff --git a/flux/modules/image_embedders.py b/flux/modules/image_embedders.py new file mode 100644 index 0000000..aa26d9b --- /dev/null +++ b/flux/modules/image_embedders.py @@ -0,0 +1,99 @@ +import cv2 +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image +from safetensors.torch import load_file as load_sft +from torch import nn +from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel + +from flux.util import print_load_warning + + +class DepthImageEncoder: + depth_model_name = "LiheYoung/depth-anything-large-hf" + + def __init__(self, device): + self.device = device + self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device) + self.processor = AutoProcessor.from_pretrained(self.depth_model_name) + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + hw = img.shape[-2:] + + img = torch.clamp(img, -1.0, 1.0) + img_byte = ((img + 1.0) * 127.5).byte() + + img = self.processor(img_byte, return_tensors="pt")["pixel_values"] + depth = self.depth_model(img.to(self.device)).predicted_depth + depth = repeat(depth, "b h w -> b 3 h w") + depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True) + + depth = depth / 127.5 - 1.0 + return depth + + +class CannyImageEncoder: + def __init__( + self, + device, + min_t: int = 50, + max_t: int = 200, + ): + self.device = device + self.min_t = min_t + self.max_t = max_t + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + assert img.shape[0] == 1, "Only batch size 1 is supported" + + img = rearrange(img[0], "c h w -> h w c") + img = torch.clamp(img, -1.0, 1.0) + img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8) + + # Apply Canny edge detection + canny = cv2.Canny(img_np, self.min_t, self.max_t) + + # Convert back to torch tensor and reshape + canny = torch.from_numpy(canny).float() / 127.5 - 1.0 + canny = rearrange(canny, "h w -> 1 1 h w") + canny = repeat(canny, "b 1 ... 
-> b 3 ...") + return canny.to(self.device) + + +class ReduxImageEncoder(nn.Module): + siglip_model_name = "google/siglip-so400m-patch14-384" + + def __init__( + self, + device, + redux_path: str, + redux_dim: int = 1152, + txt_in_features: int = 4096, + dtype=torch.bfloat16, + ) -> None: + super().__init__() + + self.redux_dim = redux_dim + self.device = device if isinstance(device, torch.device) else torch.device(device) + self.dtype = dtype + + with self.device: + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) + + sd = load_sft(redux_path, device=str(device)) + missing, unexpected = self.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + + self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype) + self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name) + + def __call__(self, x: Image.Image) -> torch.Tensor: + imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True) + + _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state + + projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x))) + + return projected_x diff --git a/flux/modules/layers copy.py b/flux/modules/layers copy.py new file mode 100644 index 0000000..e032ea3 --- /dev/null +++ b/flux/modules/layers copy.py @@ -0,0 +1,327 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from flux.math import attention, rope + +def get_linear_split_map(): + hidden_size = 3072 + _modules_map = { + "qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]}, + "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]} + } + return split_linear_modules_map + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, 
qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated.mul_(1 + img_mod1.scale) + img_modulated.add_(img_mod1.shift) + + shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) ) + img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2) + img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2) + img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2) + del img_modulated + + # img_qkv = self.img_attn.qkv(img_modulated) + # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated.mul_(1 + txt_mod1.scale) + txt_modulated.add_(txt_mod1.shift) + # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + + shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) ) + txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2) + txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2) + txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2) + del txt_modulated + + # txt_qkv = self.txt_attn.qkv(txt_modulated) + # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + qkv_list = [q, k, v] + del q, k, v + attn = attention(qkv_list, pe=pe) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img.addcmul_(self.img_attn.proj(img_attn), img_mod1.gate) + img.addcmul_(self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift), img_mod2.gate) + + # img = img + img_mod1.gate * self.img_attn.proj(img_attn) + # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate) + txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate) + # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
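# A standalone sketch of the shift/scale/gate modulation pattern the blocks above rely on.
# It only checks that the in-place formulation used here for VRAM savings (mul_/add_/addcmul_)
# matches the plain out-of-place one; the shapes and the mapping f below are made up for
# illustration and are not part of the patch.
import torch
from torch import nn

def modulated_residual(x, shift, scale, gate, f):
    # out-of-place reference: x + gate * f((1 + scale) * x + shift)
    return x + gate * f((1 + scale) * x + shift)

def modulated_residual_inplace(x, shift, scale, gate, f):
    h = x.clone()
    h.mul_(1 + scale)          # h = (1 + scale) * x
    h.add_(shift)              # h = h + shift
    out = x.clone()
    out.addcmul_(f(h), gate)   # out = out + gate * f(h)
    return out

x = torch.randn(2, 8, 16)
shift, scale, gate = (torch.randn(2, 1, 16) for _ in range(3))
f = nn.Linear(16, 16)
assert torch.allclose(modulated_residual(x, shift, scale, gate, f),
                      modulated_residual_inplace(x, shift, scale, gate, f), atol=1e-5)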
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = self.pre_norm(x) + x_mod.mul_(1 + mod.scale) + x_mod.add_(mod.shift) + + ##### More spagheti VRAM optimizations done by DeepBeepMeep ! + # I am sure you are a nice person and as you copy this code, you will give me proper credits: + # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter + + # x_mod = (1 + mod.scale) * x + mod.shift + + shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) ) + q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2) + k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2) + v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2) + + # shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) ) + # txt_q = self.linear1_attn_q(txt_mod).view(*shape) + # txt_k = self.linear1_attn_k(txt_mod).view(*shape) + # txt_v = self.linear1_attn_v(txt_mod).view(*shape) + + # qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + qkv_list = [q, k, v] + del q, k, v + attn = attention(qkv_list, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + + x_mod_shape = x_mod.shape + x_mod = x_mod.view(-1, x_mod.shape[-1]) + chunk_size = int(x_mod_shape[1]/6) + x_chunks = torch.split(x_mod, chunk_size) + attn = attn.view(-1, attn.shape[-1]) + attn_chunks =torch.split(attn, chunk_size) + for x_chunk, attn_chunk in zip(x_chunks, attn_chunks): + mlp_chunk = self.linear1_mlp(x_chunk) + mlp_chunk = self.mlp_act(mlp_chunk) + attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1) + del attn_chunk, mlp_chunk + x_chunk[...] 
= self.linear2(attn_mlp_chunk) + del attn_mlp_chunk + x_mod = x_mod.view(x_mod_shape) + x.addcmul_(x_mod, mod.gate) + return x + + # output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + # return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/flux/modules/layers.py b/flux/modules/layers.py new file mode 100644 index 0000000..0fbe404 --- /dev/null +++ b/flux/modules/layers.py @@ -0,0 +1,328 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from flux.math import attention, rope + +def get_linear_split_map(): + hidden_size = 3072 + split_linear_modules_map = { + "qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]}, + "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]} + } + return split_linear_modules_map + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
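# Standalone sketch of what the split map above is meant to drive: carving one fused
# nn.Linear (e.g. "qkv") into separate sub-projections that reuse slices of its weights, so
# the attention code can call q/k/v individually. The helper below and its name are
# assumptions made for illustration; the map itself is presumably consumed elsewhere in the
# patch (e.g. by the loading/offload machinery), not in this file.
import torch
from torch import nn

def split_fused_linear(fused: nn.Linear, names: list[str], split_sizes: list[int]) -> dict[str, nn.Linear]:
    subs, start = {}, 0
    for name, size in zip(names, split_sizes):
        # each sub-projection reuses a row slice of the fused weight (and bias, if any)
        sub = nn.Linear(fused.in_features, size, bias=fused.bias is not None)
        sub.weight = nn.Parameter(fused.weight[start:start + size].clone())
        if fused.bias is not None:
            sub.bias = nn.Parameter(fused.bias[start:start + size].clone())
        subs[name] = sub
        start += size
    return subs

hidden = 16                      # the real model uses hidden_size = 3072
fused_qkv = nn.Linear(hidden, 3 * hidden)
subs = split_fused_linear(fused_qkv, ["q", "k", "v"], [hidden] * 3)
x = torch.randn(1, 4, hidden)
assert torch.allclose(torch.cat([subs["q"](x), subs["k"](x), subs["v"](x)], dim=-1), fused_qkv(x), atol=1e-5)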
+    """
+    t = time_factor * t
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        t.device
+    )
+
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    if torch.is_floating_point(t):
+        embedding = embedding.to(t)
+    return embedding
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.silu = nn.SiLU()
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: Tensor):
+        x_dtype = x.dtype
+        x = x.float()
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+        return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+
+class QKNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = RMSNorm(dim)
+        self.key_norm = RMSNorm(dim)
+
+    def forward(self, q: Tensor | None, k: Tensor | None, v: Tensor) -> Tensor:
+        # normalizes one projection at a time: k if provided, otherwise q
+        if k is not None:
+            return self.key_norm(k).to(v)
+        else:
+            return self.query_norm(q).to(v)
+        # q = self.query_norm(q)
+        # k = self.key_norm(k)
+        # return q.to(v), k.to(v)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+        raise NotImplementedError("SelfAttention.forward is not used; the fused qkv projection is split into q/k/v at load time")
+
+@dataclass
+class ModulationOut:
+    shift: Tensor
+    scale: Tensor
+    gate: Tensor
+
+
+def split_mlp(mlp, x, divide = 4):
+    x_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    chunk_size = int(x_shape[1]/divide)
+    x_chunks = torch.split(x, chunk_size)
+    for i, x_chunk in enumerate(x_chunks):
+        mlp_chunk = mlp[0](x_chunk)
+        mlp_chunk = mlp[1](mlp_chunk)
+        x_chunk[...]
= mlp[2](mlp_chunk) + return x.reshape(x_shape) + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated.mul_(1 + img_mod1.scale) + img_modulated.add_(img_mod1.shift) + + shape = (*img_modulated.shape[:2], self.num_heads, int(img_modulated.shape[-1] / self.num_heads) ) + img_q = self.img_attn.q(img_modulated).view(*shape).transpose(1,2) + img_k = self.img_attn.k(img_modulated).view(*shape).transpose(1,2) + img_v = self.img_attn.v(img_modulated).view(*shape).transpose(1,2) + del img_modulated + + + img_q= self.img_attn.norm(img_q, None, img_v) + img_k = self.img_attn.norm(None, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated.mul_(1 + txt_mod1.scale) + txt_modulated.add_(txt_mod1.shift) + + shape = (*txt_modulated.shape[:2], self.num_heads, int(txt_modulated.shape[-1] / self.num_heads) ) + txt_q = self.txt_attn.q(txt_modulated).view(*shape).transpose(1,2) + txt_k = self.txt_attn.k(txt_modulated).view(*shape).transpose(1,2) + txt_v = self.txt_attn.v(txt_modulated).view(*shape).transpose(1,2) + del txt_modulated + + + txt_q = self.txt_attn.norm(txt_q, None, txt_v) + txt_k = self.txt_attn.norm(None, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v + + qkv_list = [q, k, v] + del q, k, v + attn = attention(qkv_list, pe=pe) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img.addcmul_(self.img_attn.proj(img_attn), 
img_mod1.gate) + mod_img = self.img_norm2(img) + mod_img.mul_(1 + img_mod2.scale) + mod_img.add_(img_mod2.shift) + mod_img = split_mlp(self.img_mlp, mod_img) + # mod_img = self.img_mlp(mod_img) + img.addcmul_( mod_img, img_mod2.gate) + mod_img = None + + # calculate the txt blocks + txt.addcmul_(self.txt_attn.proj(txt_attn), txt_mod1.gate) + txt.addcmul_(self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift), txt_mod2.gate) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = self.pre_norm(x) + x_mod.mul_(1 + mod.scale) + x_mod.add_(mod.shift) + + ##### More spagheti VRAM optimizations done by DeepBeepMeep ! + # I am sure you are a nice person and as you copy this code, you will give me proper credits: + # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter + + # x_mod = (1 + mod.scale) * x + mod.shift + + shape = (*x_mod.shape[:2], self.num_heads, int(x_mod.shape[-1] / self.num_heads) ) + q = self.linear1_attn_q(x_mod).view(*shape).transpose(1,2) + k = self.linear1_attn_k(x_mod).view(*shape).transpose(1,2) + v = self.linear1_attn_v(x_mod).view(*shape).transpose(1,2) + + q = self.norm(q, None, v) + k = self.norm(None, k, v) + + # compute attention + qkv_list = [q, k, v] + del q, k, v + attn = attention(qkv_list, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + + x_mod_shape = x_mod.shape + x_mod = x_mod.view(-1, x_mod.shape[-1]) + chunk_size = int(x_mod_shape[1]/6) + x_chunks = torch.split(x_mod, chunk_size) + attn = attn.view(-1, attn.shape[-1]) + attn_chunks =torch.split(attn, chunk_size) + for x_chunk, attn_chunk in zip(x_chunks, attn_chunks): + mlp_chunk = self.linear1_mlp(x_chunk) + mlp_chunk = self.mlp_act(mlp_chunk) + attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1) + del attn_chunk, mlp_chunk + x_chunk[...] 
= self.linear2(attn_mlp_chunk) + del attn_mlp_chunk + x_mod = x_mod.view(x_mod_shape) + x.addcmul_(x_mod, mod.gate) + return x + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/flux/modules/lora.py b/flux/modules/lora.py new file mode 100644 index 0000000..556027e --- /dev/null +++ b/flux/modules/lora.py @@ -0,0 +1,94 @@ +import torch +from torch import nn + + +def replace_linear_with_lora( + module: nn.Module, + max_rank: int, + scale: float = 1.0, +) -> None: + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + new_lora = LinearLora( + in_features=child.in_features, + out_features=child.out_features, + bias=child.bias, + rank=max_rank, + scale=scale, + dtype=child.weight.dtype, + device=child.weight.device, + ) + + new_lora.weight = child.weight + new_lora.bias = child.bias if child.bias is not None else None + + setattr(module, name, new_lora) + else: + replace_linear_with_lora( + module=child, + max_rank=max_rank, + scale=scale, + ) + + +class LinearLora(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + rank: int, + dtype: torch.dtype, + device: torch.device, + lora_bias: bool = True, + scale: float = 1.0, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias is not None, + device=device, + dtype=dtype, + *args, + **kwargs, + ) + + assert isinstance(scale, float), "scale must be a float" + + self.scale = scale + self.rank = rank + self.lora_bias = lora_bias + self.dtype = dtype + self.device = device + + if rank > (new_rank := min(self.out_features, self.in_features)): + self.rank = new_rank + + self.lora_A = nn.Linear( + in_features=in_features, + out_features=self.rank, + bias=False, + dtype=dtype, + device=device, + ) + self.lora_B = nn.Linear( + in_features=self.rank, + out_features=out_features, + bias=self.lora_bias, + dtype=dtype, + device=device, + ) + + def set_scale(self, scale: float) -> None: + assert isinstance(scale, float), "scalar value must be a float" + self.scale = scale + + def forward(self, input: torch.Tensor) -> torch.Tensor: + base_out = super().forward(input) + + _lora_out_B = self.lora_B(self.lora_A(input)) + lora_update = _lora_out_B * self.scale + + return base_out + lora_update diff --git a/flux/sampling.py b/flux/sampling.py new file mode 100644 index 0000000..7581dea --- /dev/null +++ b/flux/sampling.py @@ -0,0 +1,392 @@ +import math +from typing import Callable + +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image +from torch import Tensor + +from .model import Flux +from .modules.autoencoder import AutoEncoder +from .modules.conditioner import HFEmbedder +from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder +from .util import PREFERED_KONTEXT_RESOLUTIONS +from einops import rearrange, repeat + + +def get_noise( + num_samples: int, + height: int, + width: 
int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + dtype=dtype, + generator=torch.Generator(device="cuda").manual_seed(seed), + ).to(device) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def prepare_control( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + ae: AutoEncoder, + encoder: DepthImageEncoder | CannyImageEncoder, + img_cond_path: str, +) -> dict[str, Tensor]: + # load and encode the conditioning image + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + + width = w * 8 + height = h * 8 + img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS) + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + + with torch.no_grad(): + img_cond = encoder(img_cond) + img_cond = ae.encode(img_cond) + + img_cond = img_cond.to(torch.bfloat16) + img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... 
-> bs ...", bs=bs) + + return_dict = prepare(t5, clip, img, prompt) + return_dict["img_cond"] = img_cond + return return_dict + + +def prepare_fill( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + ae: AutoEncoder, + img_cond_path: str, + mask_path: str, +) -> dict[str, Tensor]: + # load and encode the conditioning image and the mask + bs, _, _, _ = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + + mask = Image.open(mask_path).convert("L") + mask = np.array(mask) + mask = torch.from_numpy(mask).float() / 255.0 + mask = rearrange(mask, "h w -> 1 1 h w") + + with torch.no_grad(): + img_cond = img_cond.to(img.device) + mask = mask.to(img.device) + img_cond = img_cond * (1 - mask) + img_cond = ae.encode(img_cond) + mask = mask[:, 0, :, :] + mask = mask.to(torch.bfloat16) + mask = rearrange( + mask, + "b (h ph) (w pw) -> b (ph pw) h w", + ph=8, + pw=8, + ) + mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if mask.shape[0] == 1 and bs > 1: + mask = repeat(mask, "1 ... -> bs ...", bs=bs) + + img_cond = img_cond.to(torch.bfloat16) + img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + img_cond = torch.cat((img_cond, mask), dim=-1) + + return_dict = prepare(t5, clip, img, prompt) + return_dict["img_cond"] = img_cond.to(img.device) + return return_dict + + +def prepare_redux( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + encoder: ReduxImageEncoder, + img_cond_path: str, +) -> dict[str, Tensor]: + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + with torch.no_grad(): + img_cond = encoder(img_cond) + + img_cond = img_cond.to(torch.bfloat16) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + txt = torch.cat((txt, img_cond.to(txt)), dim=-2) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def prepare_kontext( + t5: HFEmbedder, + clip: HFEmbedder, + prompt: str | list[str], + ae: AutoEncoder, + img_cond: str, + seed: int, + device: torch.device, + target_width: int | None = None, + target_height: int | None = None, + bs: int = 1, +) -> tuple[dict[str, Tensor], int, int]: + # load and encode the conditioning image + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + width, height = img_cond.size + aspect_ratio = width / height + + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) + + width = 2 * int(width / 16) + height = 2 * int(height / 16) + + img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS) + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + img_cond_orig = img_cond.clone() + + with torch.no_grad(): + img_cond = ae.encode(img_cond.to(device)) + + img_cond = img_cond.to(torch.bfloat16) + img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + # image ids are the same as base image with the first dimension set to 1 + # instead of 0 + img_cond_ids = torch.zeros(height // 2, width // 2, 3) + img_cond_ids[..., 0] = 1 + img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) + + if target_width is None: + target_width = 8 * width + if target_height is None: + target_height = 8 * height + + img = get_noise( + bs, + target_height, + target_width, + device=device, + dtype=torch.bfloat16, + seed=seed, + ) + + return_dict = prepare(t5, clip, img, prompt) + return_dict["img_cond_seq"] = img_cond + return_dict["img_cond_seq_ids"] = img_cond_ids.to(device) + return_dict["img_cond_orig"] = img_cond_orig + return return_dict, target_height, target_width + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + # extra img tokens (channel-wise) + img_cond: Tensor | None = None, + # extra img tokens (sequence-wise) + img_cond_seq: Tensor | None = None, + 
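# Small numeric sketch of the schedule defined above: timesteps run linearly from 1 to 0
# and, when shift=True, are warped by time_shift() with mu growing linearly in the image
# token count, so larger images keep more of the step budget at high noise. The numbers
# below are arbitrary and only restate the formulas from get_lin_function()/time_shift().
import math
import torch

def _time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

num_steps, image_seq_len = 4, 1024            # e.g. a 512x512 image packs to (512/16)**2 = 1024 tokens
m = (1.15 - 0.5) / (4096 - 256)               # the default get_lin_function() slope
mu = 0.5 + m * (image_seq_len - 256)
t = torch.linspace(1, 0, num_steps + 1)
print(_time_shift(mu, 1.0, t))                # shifted values stay closer to 1 than the linear grid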
img_cond_seq_ids: Tensor | None = None, + callback=None, + pipeline=None, + loras_slists=None, + unpack_latent = None, +): + + kwargs = {'pipeline': pipeline, 'callback': callback} + if callback != None: + callback(-1, None, True) + + updated_num_steps= len(timesteps) -1 + if callback != None: + from wgp import update_loras_slists + update_loras_slists(model, loras_slists, updated_num_steps) + callback(-1, None, True, override_num_inference_steps = updated_num_steps) + from mmgp import offload + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + offload.set_step_no_for_lora(model, i) + if pipeline._interrupt: + return None + + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + img_input = img + img_input_ids = img_ids + if img_cond is not None: + img_input = torch.cat((img, img_cond), dim=-1) + if img_cond_seq is not None: + assert ( + img_cond_seq_ids is not None + ), "You need to provide either both or neither of the sequence conditioning" + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + pred = model( + img=img_input, + img_ids=img_input_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + **kwargs + ) + if pred == None: return None + + if img_input_ids is not None: + pred = pred[:, : img.shape[1]] + + img += (t_prev - t_curr) * pred + if callback is not None: + preview = unpack_latent(img).transpose(0,1) + callback(i, preview, False) + + + return img + + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/flux/to_remove/cli.py b/flux/to_remove/cli.py new file mode 100644 index 0000000..ed0b1c8 --- /dev/null +++ b/flux/to_remove/cli.py @@ -0,0 +1,302 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from transformers import pipeline + +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import ( + check_onnx_access_for_trt, + configs, + load_ae, + load_clip, + load_flow_model, + load_t5, + save_image, +) + +NSFW_THRESHOLD = 0.85 + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + ) + elif 
prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +@torch.inference_mode() +def main( + name: str = "flux-schnell", + width: int = 1360, + height: int = 768, + seed: int | None = None, + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 2.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + trt: bool = False, + trt_transformer_precision: str = "bf16", + track_usage: bool = False, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. 
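# Shape sketch of the packing convention used by get_noise()/prepare()/unpack() above and
# of the CLIs' rounding of height/width to multiples of 16: an HxW image corresponds to a
# 16-channel latent at H/8 x W/8, which is packed 2x2 into (H/16)*(W/16) tokens of 64
# channels. The sizes below are arbitrary and only illustrate the bookkeeping.
import math
import torch
from einops import rearrange

height, width = 768, 1360
height, width = 16 * (height // 16), 16 * (width // 16)       # same rounding as the CLI

x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16))   # as in get_noise()
tokens = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)      # as in prepare()
print(tokens.shape)                           # -> torch.Size([1, 4080, 64]) for 768x1360

restored = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)  # as in unpack()
assert torch.equal(restored, x)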
+ + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + trt: use TensorRT backend for optimized inference + trt_transformer_precision: specify transformer precision for inference + track_usage: track usage of the model for licensing purposes + """ + + prompt = prompt.split("|") + if len(prompt) == 1: + prompt = prompt[0] + additional_prompts = None + else: + additional_prompts = prompt[1:] + prompt = prompt[0] + + assert not ( + (additional_prompts is not None) and loop + ), "Do not provide additional prompts and set loop to True" + + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + # allow for packing and conversion to latent space + height = 16 * (height // 16) + width = 16 * (width // 16) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + if not trt: + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + else: + # lazy import to make install optional + from flux.trt.trt_manager import ModuleName, TRTManager + + # Check if we need ONNX model access (which requires authentication for FLUX models) + onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) + + trt_ctx_manager = TRTManager( + trt_transformer_precision=trt_transformer_precision, + trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"), + ) + engines = trt_ctx_manager.load_engines( + model_name=name, + module_names={ + ModuleName.CLIP, + ModuleName.TRANSFORMER, + ModuleName.T5, + ModuleName.VAE, + }, + engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), + custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), + trt_image_height=height, + trt_image_width=width, + trt_batch_size=1, + trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), + trt_static_batch=False, + trt_static_shape=False, + ) + + ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) + model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) + clip = engines[ModuleName.CLIP].to(torch_device) + t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + 
num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare(t5, clip, x, prompt=opts.prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + + idx = save_image( + nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage + ) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + elif additional_prompts: + next_prompt = additional_prompts.pop(0) + opts.prompt = next_prompt + else: + opts = None + + if trt: + trt_ctx_manager.stop_runtime() + + +if __name__ == "__main__": + Fire(main) diff --git a/flux/to_remove/cli_control.py b/flux/to_remove/cli_control.py new file mode 100644 index 0000000..73a6943 --- /dev/null +++ b/flux/to_remove/cli_control.py @@ -0,0 +1,390 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from transformers import pipeline + +from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder +from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + lora_scale: float | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + 
) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]: + changed = False + + if options is None: + return None, changed + + user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the lora scale or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/q"): + print("Quitting") + return None, changed + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.lora_scale = float(prompt) + changed = True + return options, changed + + +@torch.inference_mode() +def main( + name: str, + width: int = 1024, + height: int = 1024, + seed: int | None = None, + prompt: str = "a robot made out of gold", + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int = 50, + loop: bool = False, + guidance: float | None = None, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/robot.webp", + lora_scale: float | 
None = 0.85, + trt: bool = False, + trt_transformer_precision: str = "bf16", + track_usage: bool = False, + **kwargs: dict | None, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + trt: use TensorRT backend for optimized inference + trt_transformer_precision: specify transformer precision for inference + track_usage: track usage of the model for licensing purposes + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + if "lora" in name: + assert not trt, "TRT does not support LORA" + assert name in [ + "flux-dev-canny", + "flux-dev-depth", + "flux-dev-canny-lora", + "flux-dev-depth-lora", + ], f"Got unknown model name: {name}" + + if guidance is None: + if name in ["flux-dev-canny", "flux-dev-canny-lora"]: + guidance = 30.0 + elif name in ["flux-dev-depth", "flux-dev-depth-lora"]: + guidance = 10.0 + else: + raise NotImplementedError() + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + if name in ["flux-dev-depth", "flux-dev-depth-lora"]: + img_embedder = DepthImageEncoder(torch_device) + elif name in ["flux-dev-canny", "flux-dev-canny-lora"]: + img_embedder = CannyImageEncoder(torch_device) + else: + raise NotImplementedError() + + if not trt: + # init all components + t5 = load_t5(torch_device, max_length=512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + else: + # lazy import to make install optional + from flux.trt.trt_manager import ModuleName, TRTManager + + trt_ctx_manager = TRTManager( + trt_transformer_precision=trt_transformer_precision, + trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), + ) + + engines = trt_ctx_manager.load_engines( + model_name=name, + module_names={ + ModuleName.CLIP, + ModuleName.TRANSFORMER, + ModuleName.T5, + ModuleName.VAE, + ModuleName.VAE_ENCODER, + }, + engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), + custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""), + trt_image_height=height, + trt_image_width=width, + trt_batch_size=1, + trt_static_batch=kwargs.get("static_batch", True), + trt_static_shape=kwargs.get("static_shape", True), + ) + + ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) + model = 
engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) + clip = engines[ModuleName.CLIP].to(torch_device) + t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) + + # set lora scale + if "lora" in name and lora_scale is not None: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(lora_scale) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + lora_scale=lora_scale, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + if "lora" in name: + opts, changed = parse_lora_scale(opts) + if changed: + # update the lora scale: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(opts.lora_scale) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) + inp = prepare_control( + t5, + clip, + x, + prompt=opts.prompt, + ae=ae, + encoder=img_embedder, + img_cond_path=opts.img_cond_path, + ) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs and AE to CPU, load model to gpu + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image( + nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage + ) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + if "lora" in name: + opts, changed = parse_lora_scale(opts) + if changed: + # update the lora scale: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(opts.lora_scale) + else: + opts = None + + if trt: + trt_ctx_manager.stop_runtime() + + +if __name__ == "__main__": + Fire(main) diff --git a/flux/to_remove/cli_fill.py b/flux/to_remove/cli_fill.py new file mode 100644 index 0000000..ab78c50 --- /dev/null +++ b/flux/to_remove/cli_fill.py @@ -0,0 +1,334 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from PIL import Image +from transformers import pipeline + +from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + img_mask_path: str + + +def 
parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + else: + with Image.open(img_cond_path) as img: + width, height = img.size + + if width % 32 != 0 or height % 32 != 0: + print(f"Image dimensions must be divisible by 32, got {width}x{height}") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning mask or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_mask_path = input(user_question) + + if img_mask_path.startswith("/"): + if img_mask_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_mask_path.startswith("/h"): + print(f"Got invalid command '{img_mask_path}'\n{usage}") + print(usage) + continue + + if img_mask_path == "": + break + + if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith( + (".jpg", ".jpeg", ".png", 
".webp") + ): + print(f"File '{img_mask_path}' does not exist or is not a valid image file") + continue + else: + with Image.open(img_mask_path) as img: + width, height = img.size + + if width % 32 != 0 or height % 32 != 0: + print(f"Image dimensions must be divisible by 32, got {width}x{height}") + continue + else: + with Image.open(options.img_cond_path) as img_cond: + img_cond_width, img_cond_height = img_cond.size + + if width != img_cond_width or height != img_cond_height: + print( + f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}" + ) + continue + + options.img_mask_path = img_mask_path + break + + return options + + +@torch.inference_mode() +def main( + seed: int | None = None, + prompt: str = "a white paper cup", + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int = 50, + loop: bool = False, + guidance: float = 30.0, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/cup.png", + img_mask_path: str = "assets/cup_mask.png", + track_usage: bool = False, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. This demo assumes that the conditioning image and mask have + the same shape and that height and width are divisible by 32. + + Args: + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + img_mask_path: path to conditioning mask (jpeg/png/webp) + track_usage: track usage of the model for licensing purposes + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + name = "flux-dev-fill" + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=128) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + with Image.open(img_cond_path) as img: + width, height = img.size + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + img_mask_path=img_mask_path, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + with Image.open(opts.img_cond_path) as img: + width, height = img.size + opts.height = height + opts.width = width + + opts = parse_img_mask_path(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + 
print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) + inp = prepare_fill( + t5, + clip, + x, + prompt=opts.prompt, + ae=ae, + img_cond_path=opts.img_cond_path, + mask_path=opts.img_mask_path, + ) + + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs and AE to CPU, load model to gpu + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image( + nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage + ) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + with Image.open(opts.img_cond_path) as img: + width, height = img.size + opts.height = height + opts.width = width + + opts = parse_img_mask_path(opts) + else: + opts = None + + +if __name__ == "__main__": + Fire(main) diff --git a/flux/to_remove/cli_kontext.py b/flux/to_remove/cli_kontext.py new file mode 100644 index 0000000..17ad6a1 --- /dev/null +++ b/flux/to_remove/cli_kontext.py @@ -0,0 +1,368 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire + +from flux.content_filters import PixtralContentFilter +from flux.sampling import denoise, get_schedule, prepare_kontext, unpack +from flux.util import ( + aspect_ratio_to_height_width, + check_onnx_access_for_trt, + load_ae, + load_clip, + load_flow_model, + load_t5, + save_image, +) + + +@dataclass +class SamplingOptions: + prompt: str + width: int | None + height: int | None + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/ar :' will set the aspect ratio of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/ar"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, ratio_prompt = prompt.split() + if ratio_prompt == "auto": + options.width = None + options.height = None + print("Setting resolution to input image resolution.") + else: + options.width, options.height = aspect_ratio_to_height_width(ratio_prompt) + print(f"Setting resolution to {options.width} x {options.height}.") + elif 
prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + if height == "auto": + options.height = None + else: + options.height = 16 * (int(height) // 16) + if options.height is not None and options.width is not None: + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + ) + else: + print(f"Setting resolution to {options.width} x {options.height}.") + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write a path to an image directly, leave this field empty " + "to repeat the last input image or write a command starting with a slash:\n" + "- '/q' to quit\n\n" + "The input image will be edited by FLUX.1 Kontext creating a new image based" + "on your instruction prompt." + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +@torch.inference_mode() +def main( + name: str = "flux-dev-kontext", + aspect_ratio: str | None = None, + seed: int | None = None, + prompt: str = "replace the logo with the text 'Black Forest Labs'", + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int = 30, + loop: bool = False, + guidance: float = 2.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/cup.png", + trt: bool = False, + trt_transformer_precision: str = "bf16", + track_usage: bool = False, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. 
+ + Args: + height: height of the sample in pixels (should be a multiple of 16), None + defaults to the size of the conditioning + width: width of the sample in pixels (should be a multiple of 16), None + defaults to the size of the conditioning + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + trt: use TensorRT backend for optimized inference + track_usage: track usage of the model for licensing purposes + """ + assert name == "flux-dev-kontext", f"Got unknown model name: {name}" + + torch_device = torch.device(device) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + if aspect_ratio is None: + width = None + height = None + else: + width, height = aspect_ratio_to_height_width(aspect_ratio) + + if not trt: + t5 = load_t5(torch_device, max_length=512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + else: + # lazy import to make install optional + from flux.trt.trt_manager import ModuleName, TRTManager + + # Check if we need ONNX model access (which requires authentication for FLUX models) + onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) + + trt_ctx_manager = TRTManager( + trt_transformer_precision=trt_transformer_precision, + trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), + ) + engines = trt_ctx_manager.load_engines( + model_name=name, + module_names={ + ModuleName.CLIP, + ModuleName.TRANSFORMER, + ModuleName.T5, + }, + engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), + custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), + trt_image_height=height, + trt_image_width=width, + trt_batch_size=1, + trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), + trt_static_batch=False, + trt_static_shape=False, + ) + + model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) + clip = engines[ModuleName.CLIP].to(torch_device) + t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) + + ae = load_ae(name, device="cpu" if offload else torch_device) + content_filter = PixtralContentFilter(torch.device("cpu")) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + if content_filter.test_txt(opts.prompt): + print("Your prompt has been automatically flagged. 
Please choose another prompt.") + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + continue + if content_filter.test_image(opts.img_cond_path): + print("Your input image has been automatically flagged. Please choose another image.") + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + continue + + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) + inp, height, width = prepare_kontext( + t5=t5, + clip=clip, + prompt=opts.prompt, + ae=ae, + img_cond_path=opts.img_cond_path, + target_width=opts.width, + target_height=opts.height, + bs=1, + seed=opts.seed, + device=torch_device, + ) + from safetensors.torch import save_file + + save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft") + inp.pop("img_cond_orig") + opts.seed = None + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs and AE to CPU, load model to gpu + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + t00 = time.time() + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + torch.cuda.synchronize() + t01 = time.time() + print(f"Denoising took {t01 - t00:.3f}s") + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), height, width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + ae_dev_t0 = time.perf_counter() + x = ae.decode(x) + torch.cuda.synchronize() + ae_dev_t1 = time.perf_counter() + print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s") + + if content_filter.test_image(x.cpu()): + print( + "Your output image has been automatically flagged. Choose another prompt/image or try again." 
+ ) + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + continue + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + +
+ idx = save_image( + None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage + ) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + +
+if __name__ == "__main__": + Fire(main) diff --git a/flux/to_remove/cli_redux.py b/flux/to_remove/cli_redux.py new file mode 100644 index 0000000..71e59e1 --- /dev/null +++ b/flux/to_remove/cli_redux.py @@ -0,0 +1,290 @@
+import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from transformers import pipeline + +from flux.modules.image_embedders import ReduxImageEncoder +from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack +from flux.util import ( + get_checkpoint_path, + load_ae, + load_clip, + load_flow_model, + load_t5, + save_image, +) + +
+@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + +
+def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Leave this field empty to do nothing " + "or write a command starting with a slash:\n"
+ "- '/w <width>' will set the width of the generated image\n" + "- '/h <height>' will set the height of the generated image\n" + "- '/s <seed>' sets the next seed\n" + "- '/g <guidance>' sets the guidance (flux-dev only)\n" + "- '/n <steps>' sets the number of steps\n" + "- '/q' to quit" + ) + +
+ while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + )
+ elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height * options.width / 1e6:.2f}MP)" + )
+ elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}")
+ elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}")
+ elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}")
+ elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + return options + +
+def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + +
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +@torch.inference_mode() +def main( + name: str = "flux-dev", + width: int = 1360, + height: int = 768, + seed: int | None = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 2.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/robot.webp", + track_usage: bool = False, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the base model to use (either 'flux-dev' or 'flux-schnell') + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + offload: offload models to CPU when not in use + output_dir: where to save the output images + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + track_usage: track usage of the model for licensing purposes + """ + + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + if name not in (available := ["flux-dev", "flux-schnell"]): + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + # Download and initialize the Redux adapter + redux_path = str( + get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX") + ) + img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path) + + rng = torch.Generator(device="cpu") + prompt = "" + opts = SamplingOptions( + 
prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare_redux( + t5, + clip, + x, + prompt=opts.prompt, + encoder=img_embedder, + img_cond_path=opts.img_cond_path, + ) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image( + nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage + ) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + + +if __name__ == "__main__": + Fire(main) diff --git a/flux/to_remove/content_filters.py b/flux/to_remove/content_filters.py new file mode 100644 index 0000000..8de89ed --- /dev/null +++ b/flux/to_remove/content_filters.py @@ -0,0 +1,171 @@ +import torch +from einops import rearrange +from PIL import Image +from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline + +PROMPT_IMAGE_INTEGRITY = """ +Task: Analyze an image to identify potential copyright concerns or depictions of public figures. + +Output: Respond with only "yes" or "no" + +Criteria for "yes": +- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.) +- The image displays a trademarked logo or brand +- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.) + +Criteria for "no": +- All other cases +- When you cannot identify the specific copyrighted work or named individual + +Critical Requirements: +1. You must be able to name the exact copyrighted work or specific person depicted +2. General references to demographics or characteristics are not sufficient +3. Base your decision solely on visual content, not interpretation +4. Provide only the one-word answer: "yes" or "no" +""".strip() + + +PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?" + +PROMPT_TEXT_INTEGRITY = """ +Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures. + +Output: Respond with only "yes" or "no" + +Criteria for "Yes": +- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.) 
+- The prompt explicitly mentions a trademarked logo or brand +- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.) + +Criteria for "No": +- All other cases +- When you cannot identify the specific copyrighted work or named individual + +Critical Requirements: +1. You must be able to name the exact copyrighted work or specific person referenced +2. General demographic descriptions or characteristics are not sufficient +3. Analyze only the prompt text, not potential image outcomes +4. Provide only the one-word answer: "yes" or "no" + +The prompt to check is: +----- +{prompt} +----- + +Does this prompt have copyright concerns or includes public figures? +""".strip() + + +class PixtralContentFilter(torch.nn.Module): + def __init__( + self, + device: torch.device = torch.device("cpu"), + nsfw_threshold: float = 0.85, + ): + super().__init__() + + model_id = "mistral-community/pixtral-12b" + self.processor = AutoProcessor.from_pretrained(model_id) + self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device) + + self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"]) + + self.nsfw_classifier = pipeline( + "image-classification", model="Falconsai/nsfw_image_detection", device=device + ) + self.nsfw_threshold = nsfw_threshold + + def yes_no_logit_processor( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Sets all tokens but yes/no to the minimum. + """ + scores_yes_token = scores[:, self.yes_token].clone() + scores_no_token = scores[:, self.no_token].clone() + scores_min = scores.min() + scores[:, :] = scores_min - 1 + scores[:, self.yes_token] = scores_yes_token + scores[:, self.no_token] = scores_no_token + return scores + + def test_image(self, image: Image.Image | str | torch.Tensor) -> bool: + if isinstance(image, torch.Tensor): + image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c") + image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy()) + elif isinstance(image, str): + image = Image.open(image) + + classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw") + if classification["score"] > self.nsfw_threshold: + return True + + # 512^2 pixels are enough for checking + w, h = image.size + f = (512**2 / (w * h)) ** 0.5 + image = image.resize((int(f * w), int(f * h))) + + chat = [ + { + "role": "user", + "content": [ + { + "type": "text", + "content": PROMPT_IMAGE_INTEGRITY, + }, + { + "type": "image", + "image": image, + }, + { + "type": "text", + "content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP, + }, + ], + } + ] + + inputs = self.processor.apply_chat_template( + chat, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(self.model.device) + + generate_ids = self.model.generate( + **inputs, + max_new_tokens=1, + logits_processor=[self.yes_no_logit_processor], + do_sample=False, + ) + return generate_ids[0, -1].item() == self.yes_token + + def test_txt(self, txt: str) -> bool: + chat = [ + { + "role": "user", + "content": [ + { + "type": "text", + "content": PROMPT_TEXT_INTEGRITY.format(prompt=txt), + }, + ], + } + ] + + inputs = self.processor.apply_chat_template( + chat, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(self.model.device) + + generate_ids = self.model.generate( + **inputs, + max_new_tokens=1, + logits_processor=[self.yes_no_logit_processor], + do_sample=False, + ) + return 
generate_ids[0, -1].item() == self.yes_token diff --git a/flux/util.py b/flux/util.py new file mode 100644 index 0000000..9b477a0 --- /dev/null +++ b/flux/util.py @@ -0,0 +1,702 @@ +import getpass +import math +import os +from dataclasses import dataclass +from pathlib import Path + +import requests +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download, login +from PIL import ExifTags, Image +from safetensors.torch import load_file as load_sft + +from flux.model import Flux, FluxLoraWrapper, FluxParams +from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams +from flux.modules.conditioner import HFEmbedder + +CHECKPOINTS_DIR = Path("checkpoints") +CHECKPOINTS_DIR.mkdir(exist_ok=True) +BFL_API_KEY = os.getenv("BFL_API_KEY") + +os.environ.setdefault("TRT_ENGINE_DIR", str(CHECKPOINTS_DIR / "trt_engines")) +(CHECKPOINTS_DIR / "trt_engines").mkdir(exist_ok=True) + + +def ensure_hf_auth(): + hf_token = os.environ.get("HF_TOKEN") + if hf_token: + print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.") + try: + login(token=hf_token) + print("Successfully authenticated with HuggingFace using HF_TOKEN") + return True + except Exception as e: + print(f"Warning: Failed to authenticate with HF_TOKEN: {e}") + + if os.path.exists(os.path.expanduser("~/.cache/huggingface/token")): + print("Already authenticated with HuggingFace") + return True + + return False + + +def prompt_for_hf_auth(): + try: + token = getpass.getpass("HF Token (hidden input): ").strip() + if not token: + print("No token provided. Aborting.") + return False + + login(token=token) + print("Successfully authenticated!") + return True + except KeyboardInterrupt: + print("\nAuthentication cancelled by user.") + return False + except Exception as auth_e: + print(f"Authentication failed: {auth_e}") + print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable") + return False + + +def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path: + """Get the local path for a checkpoint file, downloading if necessary.""" + # if os.environ.get(env_var) is not None: + # local_path = os.environ[env_var] + # if os.path.exists(local_path): + # return Path(local_path) + + # print( + # f"Trying to load model {repo_id}, {filename} from environment " + # f"variable {env_var}. But file {local_path} does not exist. " + # "Falling back to default location." 
+ # ) + + # # Create a safe directory name from repo_id + # safe_repo_name = repo_id.replace("/", "_") + # checkpoint_dir = CHECKPOINTS_DIR / safe_repo_name + # checkpoint_dir.mkdir(exist_ok=True) + + # local_path = checkpoint_dir / filename + + local_path = filename + from mmgp import offload + + if False: + print(f"Downloading {filename} from {repo_id} to {local_path}") + try: + ensure_hf_auth() + hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir) + except Exception as e: + if "gated repo" in str(e).lower() or "restricted" in str(e).lower(): + print(f"\nError: Cannot access {repo_id} -- this is a gated repository.") + + # Try one more time to authenticate + if prompt_for_hf_auth(): + # Retry the download after authentication + print("Retrying download...") + hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir) + else: + print("Authentication failed or cancelled.") + print("You can also run 'huggingface-cli login' or set HF_TOKEN environment variable") + raise RuntimeError(f"Authentication required for {repo_id}") + else: + raise e + + return local_path + + +def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: + """Download ONNX models for TRT to our checkpoints directory""" + onnx_repo_map = { + "flux-dev": "black-forest-labs/FLUX.1-dev-onnx", + "flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx", + "flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx", + "flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx", + "flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx", + "flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx", + "flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx", + } + + if model_name not in onnx_repo_map: + return None # No ONNX repository required for this model + + repo_id = onnx_repo_map[model_name] + safe_repo_name = repo_id.replace("/", "_") + onnx_dir = CHECKPOINTS_DIR / safe_repo_name + + # Map of module names to their ONNX file paths (using specified precision) + onnx_file_map = { + "clip": "clip.opt/model.onnx", + "transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx", + "transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data", + "t5": "t5.opt/model.onnx", + "t5_data": "t5.opt/backbone.onnx_data", + "vae": "vae.opt/model.onnx", + } + + # If all files exist locally, return the custom_onnx_paths format + if onnx_dir.exists(): + all_files_exist = True + custom_paths = [] + for module, onnx_file in onnx_file_map.items(): + if module.endswith("_data"): + continue # Skip data files + local_path = onnx_dir / onnx_file + if not local_path.exists(): + all_files_exist = False + break + custom_paths.append(f"{module}:{local_path}") + + if all_files_exist: + print(f"ONNX models ready in {onnx_dir}") + return ",".join(custom_paths) + + # If not all files exist, download them + print(f"Downloading ONNX models from {repo_id} to {onnx_dir}") + print(f"Using transformer precision: {trt_transformer_precision}") + onnx_dir.mkdir(exist_ok=True) + + # Download all ONNX files + for module, onnx_file in onnx_file_map.items(): + local_path = onnx_dir / onnx_file + if local_path.exists(): + continue # Already downloaded + + # Create parent directories + local_path.parent.mkdir(parents=True, exist_ok=True) + + try: + print(f"Downloading {onnx_file}") + hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir) + except Exception as e: + if "does not exist" in str(e).lower() or "not found" in 
str(e).lower(): + continue + elif "gated repo" in str(e).lower() or "restricted" in str(e).lower(): + print(f"Cannot access {repo_id} - requires license acceptance") + print("Please follow these steps:") + print(f" 1. Visit: https://huggingface.co/{repo_id}") + print(" 2. Log in to your HuggingFace account") + print(" 3. Accept the license terms and conditions") + print(" 4. Then retry this command") + raise RuntimeError(f"License acceptance required for {model_name}") + else: + # Re-raise other errors + raise + + print(f"ONNX models ready in {onnx_dir}") + + # Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2" + # Note: Only return the actual module paths, not the data file + custom_paths = [] + for module, onnx_file in onnx_file_map.items(): + if module.endswith("_data"): + continue # Skip the data file in the return paths + full_path = onnx_dir / onnx_file + if full_path.exists(): + custom_paths.append(f"{module}:{full_path}") + + return ",".join(custom_paths) + + +def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: + """Check ONNX access and download models for TRT - returns ONNX directory path""" + return download_onnx_models_for_trt(model_name, trt_transformer_precision) + + +def track_usage_via_api(name: str, n=1) -> None: + """ + Track usage of licensed models via the BFL API for commercial licensing compliance. + + For more information on licensing BFL's models for commercial use and usage reporting, + see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true + """ + assert BFL_API_KEY is not None, "BFL_API_KEY is not set" + + model_slug_map = { + "flux-dev": "flux-1-dev", + "flux-dev-kontext": "flux-1-kontext-dev", + "flux-dev-fill": "flux-tools", + "flux-dev-depth": "flux-tools", + "flux-dev-canny": "flux-tools", + "flux-dev-canny-lora": "flux-tools", + "flux-dev-depth-lora": "flux-tools", + "flux-dev-redux": "flux-tools", + } + + if name not in model_slug_map: + print(f"Skipping tracking usage for {name}, as it cannot be tracked. 
Please check the model name.") + return + + model_slug = model_slug_map[name] + url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage" + headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"} + payload = {"number_of_generations": n} + + response = requests.post(url, headers=headers, json=payload) + if response.status_code != 200: + raise Exception(f"Failed to track usage: {response.status_code} {response.text}") + else: + print(f"Successfully tracked usage for {name} with {n} generations") + + +def save_image( + nsfw_classifier, + name: str, + output_name: str, + idx: int, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, + nsfw_threshold: float = 0.85, + track_usage: bool = False, +) -> int: + fn = output_name.format(idx=idx) + print(f"Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + if nsfw_classifier is not None: + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + else: + nsfw_score = nsfw_threshold - 1.0 + + if nsfw_score < nsfw_threshold: + exif_data = Image.Exif() + if name in ["flux-dev", "flux-schnell"]: + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + else: + exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + if track_usage: + track_usage_via_api(name, 1) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + return idx + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + repo_id: str + repo_flow: str + repo_ae: str + lora_repo_id: str | None = None + lora_filename: str | None = None + + +configs = { + "flux-dev": ModelSpec( + repo_id="", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-canny": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Canny-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + 
guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-canny-lora": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora", + lora_filename="flux1-canny-dev-lora.safetensors", + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-depth": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Depth-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-depth-lora": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora", + lora_filename="flux1-depth-dev-lora.safetensors", + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-redux": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Redux-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fill": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Fill-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=384, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-kontext": 
ModelSpec( + repo_id="black-forest-labs/FLUX.1-Kontext-dev", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +PREFERED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]: + width = float(aspect_ratio.split(":")[0]) + height = float(aspect_ratio.split(":")[1]) + ratio = width / height + width = round(math.sqrt(area * ratio)) + height = round(math.sqrt(area / ratio)) + return 16 * (width // 16), 16 * (height // 16) + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux: + # Loading Flux + config = configs[name] + + ckpt_path = model_filename #config.repo_flow + + with torch.device("meta"): + if config.lora_repo_id is not None and config.lora_filename is not None: + model = FluxLoraWrapper(params=config.params).to(torch.bfloat16) + else: + model = Flux(config.params).to(torch.bfloat16) + + # print(f"Loading checkpoint: {ckpt_path}") + from mmgp import offload + offload.load_model_data(model, model_filename ) + + # # load_sft doesn't support torch.device + # sd = load_sft(ckpt_path, device=str(device)) + # sd = optionally_expand_state_dict(model, sd) + # missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + # if verbose: + # print_load_warning(missing, unexpected) + + # if config.lora_repo_id is not None and config.lora_filename is not None: + # print("Loading LoRA") + # lora_path = str(get_checkpoint_path(config.lora_repo_id, config.lora_filename, "FLUX_LORA")) + # lora_sd = load_sft(lora_path, device=str(device)) + # # loading the lora params + overwriting scale values in the norms + # missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True) + # if verbose: + # print_load_warning(missing, unexpected) + return model + + +def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("",text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16).to(device) + + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("ckpts/clip_vit_large_patch14", "", max_length=77, 
torch_dtype=torch.bfloat16, is_clip =True).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder: + config = configs[name] + ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE")) + + # Loading the autoencoder + with torch.device("meta"): + ae = AutoEncoder(config.ae_params) + + # print(f"Loading AE checkpoint: {ckpt_path}") + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict: + """ + Optionally expand the state dict to match the model's parameters shapes. + """ + for name, param in model.named_parameters(): + if name in state_dict: + if state_dict[name].shape != param.shape: + print( + f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}." + ) + # expand with zeros: + expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) + slices = tuple(slice(0, dim) for dim in state_dict[name].shape) + expanded_state_dict_weight[slices] = state_dict[name] + state_dict[name] = expanded_state_dict_weight + + return state_dict + + diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py index 7b1c6fe..6b43c38 100644 --- a/ltx_video/ltxv.py +++ b/ltx_video/ltxv.py @@ -149,6 +149,7 @@ class LTXV: self, model_filepath: str, text_encoder_filepath: str, + model_def, dtype = torch.bfloat16, VAE_dtype = torch.bfloat16, mixed_precision_transformer = False @@ -157,8 +158,8 @@ class LTXV: # if dtype == torch.float16: dtype = torch.bfloat16 self.mixed_precision_transformer = mixed_precision_transformer - self.distilled = any("lora" in name for name in model_filepath) - model_filepath = [name for name in model_filepath if not "lora" in name ] + self.model_def = model_def + self.pipeline_config = model_def["LTXV_config"] # with safe_open(ckpt_path, framework="pt") as f: # metadata = f.metadata() # config_str = metadata.get("config") @@ -220,11 +221,11 @@ class LTXV: prompt_enhancer_llm_model = None prompt_enhancer_llm_tokenizer = None - if prompt_enhancer_image_caption_model != None: - pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model - prompt_enhancer_image_caption_model._model_dtype = torch.float + # if prompt_enhancer_image_caption_model != None: + # pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model + # prompt_enhancer_image_caption_model._model_dtype = torch.float - pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model + # pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model # offload.profile(pipe, profile_no=5, extraModelsToQuantize = None, quantizeTransformer = False, budgets = { "prompt_enhancer_llm_model" : 10000, "prompt_enhancer_image_caption_model" : 10000, "vae" : 3000, "*" : 100 }, verboseLevel=2) @@ -299,14 +300,10 @@ class LTXV: conditioning_media_paths = None conditioning_start_frames = None - if self.distilled : - pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-distilled.yaml" - else: - pipeline_config = "ltx_video/configs/ltxv-13b-0.9.7-dev.yaml" # check if pipeline_config is a file - if not os.path.isfile(pipeline_config): - raise ValueError(f"Pipeline config file {pipeline_config} does not exist") - with open(pipeline_config, "r") as f: + if not os.path.isfile(self.pipeline_config): + raise ValueError(f"Pipeline config file {self.pipeline_config} does not exist") + with 
open(self.pipeline_config, "r") as f: pipeline_config = yaml.safe_load(f) @@ -520,7 +517,7 @@ def get_media_num_frames(media_path: str) -> int: return media_path.shape[1] elif isinstance(media_path, str) and any( media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]): reader = imageio.get_reader(media_path) - return min(reader.count_frames(), max_frames) + return min(reader.count_frames(), 0) # to do else: raise Exception("video format not supported") @@ -564,6 +561,3 @@ def load_media_file( raise Exception("video format not supported") return media_tensor - -if __name__ == "__main__": - main() diff --git a/wan/any2video.py b/wan/any2video.py index a2901d4..91f7258 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -106,18 +106,18 @@ class WanAny2V: # config = json.load(f) # from mmgp import safetensors2 # sd = safetensors2.torch_load_file(xmodel_filename) - + # model_filename = "c:/temp/fflf/diffusion_pytorch_model-00001-of-00007.safetensors" base_config_file = f"configs/{base_model_type}.json" - forcedConfigPath = base_config_file if len(model_filename) > 1 or base_model_type in ["flf2v_720p"] else None + forcedConfigPath = base_config_file if len(model_filename) > 1 else None + # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) # self.model = offload.load_model_data(self.model, xmodel_filename ) # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") - # self.model.to(torch.bfloat16) - # self.model.cpu() self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) - # offload.save_model(self.model, "multitalkbf16.safetensors", config_file_path=base_config_file, filter_sd=sd) + # offload.save_model(self.model, "flf2v_720p.safetensors", config_file_path=base_config_file) + # offload.save_model(self.model, "flf2v_quanto_int8_fp16_720p.safetensors", do_quantize= True, config_file_path=base_config_file) # offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd) # offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file) @@ -126,7 +126,7 @@ class WanAny2V: self.model.eval().requires_grad_(False) if save_quantized: from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file) + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) self.sample_neg_prompt = config.sample_neg_prompt @@ -477,8 +477,8 @@ class WanAny2V: any_end_frame = False if input_frames != None: _ , preframes_count, height, width = input_frames.shape - lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - clip_context = self.clip.visual([input_frames[:, -1:]]) #.to(self.param_dtype) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + clip_context = self.clip.visual([input_frames[:, -1:]]) if model_type != "flf2v_720p" else self.clip.visual([input_frames[:, -1:], input_frames[:, -1:]]) input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype) enc = torch.concat( [input_frames, torch.zeros( (3, 
frame_num-preframes_count, height, width), device=self.device, dtype= self.VAE_dtype)], @@ -488,7 +488,7 @@ class WanAny2V: preframes_count = 1 image_start = TF.to_tensor(image_start) any_end_frame = image_end != None - add_frames_for_end_image = any_end_frame and model_type not in ["fun_inp_1.3B", "fun_inp", "i2v_720p"] + add_frames_for_end_image = any_end_frame and model_type == "i2v" if any_end_frame: image_end = TF.to_tensor(image_end) if add_frames_for_end_image: @@ -517,8 +517,8 @@ class WanAny2V: img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - if image_end != None and model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :]]) + if model_type == "flf2v_720p": + clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end != None else image_start[:, None, :, :]]) else: clip_context = self.clip.visual([image_start[:, None, :, :]]) @@ -753,7 +753,7 @@ class WanAny2V: "context" : [context, context_null, context_null], "audio_scale": [audio_scale, None, None ] } - elif multitalk: + elif multitalk and audio_proj != None: gen_args = { "x" : [latent_model_input, latent_model_input, latent_model_input], "context" : [context, context_null, context_null], diff --git a/wan/multitalk/multitalk.py b/wan/multitalk/multitalk.py index e429371..3945682 100644 --- a/wan/multitalk/multitalk.py +++ b/wan/multitalk/multitalk.py @@ -184,6 +184,7 @@ def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combinat def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length = 81, vae_scale = 4, audio_window = 5): + if full_audio_embs == None: return None HUMAN_NUMBER = len(full_audio_embs) audio_end_idx = audio_start_idx + clip_length indices = (torch.arange(2 * 2 + 1) - 2) * 1 diff --git a/wgp.py b/wgp.py index ee8b669..c9fc10a 100644 --- a/wgp.py +++ b/wgp.py @@ -262,7 +262,7 @@ def process_prompt_and_add_tasks(state, model_choice): MMAudio_setting = inputs["MMAudio_setting"] if skip_steps_cache_type == "mag": - if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]: + if model_type in ["sky_df_1.3B", "sky_df_14B"]: gr.Info("Mag Cache is not supported with Diffusion Forcing") return if num_inference_steps > 50: @@ -1294,6 +1294,12 @@ def _parse_args(): help="Path to a directory that contains LTX Videos Loras" ) + parser.add_argument( + "--lora-dir-flux", + type=str, + default="loras_flux", + help="Path to a directory that contains flux images Loras" + ) parser.add_argument( @@ -1534,6 +1540,8 @@ def get_lora_dir(model_type): return root_lora_dir elif model_family == "ltxv": return args.lora_dir_ltxv + elif model_family == "flux": + return args.lora_dir_flux elif model_family =="hunyuan": if i2v: return args.lora_dir_hunyuan_i2v @@ -1622,26 +1630,13 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors", 
"wan2.1_text2video_14B_bf16.safetensors", "wan2.1_text2video_14B_quanto_int8.safetensors", -"wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors" +"wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors", "wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "wan2.1_FLF2V_720p_14B_bf16.safetensors" ]: if Path(os.path.join("ckpts" , path)).is_file(): print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.") os.remove( os.path.join("ckpts" , path)) -finetunes = {} - -wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_mbf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", - "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", - "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors", - "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors", - "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", - "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors", - ] -wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors", - "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", - "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors", - ] -ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"] +models_def = {} modules_files = { "vace_14B" : ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"], @@ -1649,31 +1644,15 @@ modules_files = { "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] } +# unused +base_types = ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", + "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "recam_1.3B", "sky_df_1.3B", "sky_df_14B", + "i2v", "flf2v_720p", "fun_inp_1.3B", "fun_inp", "ltxv_13B", + "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar", + ] -hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16v2.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors", - "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors", - 
"ckpts/hunyuan_video_custom_audio_720_bf16.safetensors", "ckpts/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors", - "ckpts/hunyuan_video_custom_edit_720_bf16.safetensors", "ckpts/hunyuan_video_custom_edit_720_quanto_bf16_int8.safetensors", - "ckpts/hunyuan_video_avatar_720_bf16.safetensors", "ckpts/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", - ] - -transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan_choices -def get_dependent_models(model_type, quantization, dtype_policy ): - # if model_type == "fantasy": - # dependent_model_type = "i2v_720p" - if model_type == "ltxv_13B_distilled": - dependent_model_type = "ltxv_13B" - # elif model_type == "vace_14B": - # dependent_model_type = "t2v" - else: - return [], [] - return [get_model_filename(dependent_model_type, quantization, dtype_policy)], [dependent_model_type] - -abstract_model_types = ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", "flf2v_720p"] -model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", - "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] - +# only needed for imported old settings files model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B": "Vace_14B", "recam_1.3B": "recammaster_1.3B", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", @@ -1683,12 +1662,12 @@ model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", " "hunyuan_avatar" : "hunyuan_video_avatar" } def get_base_model_type(model_type): - finetune_def = get_model_finetune_def(model_type) - if finetune_def == None: + model_def = get_model_def(model_type) + if model_def == None: return model_type if model_type in model_types else None # return model_type else: - return finetune_def["architecture"] + return model_def["architecture"] def are_model_types_compatible(imported_model_type, current_model_type): imported_base_model_type = get_base_model_type(imported_model_type) @@ -1697,29 +1676,26 @@ def are_model_types_compatible(imported_model_type, current_model_type): return True eqv_map = { - "i2v_720p" : "i2v", "flf2v_720p" : "i2v", "t2v_1.3B" : "t2v", "sky_df_1.3B" : "sky_df_14B", - "sky_df_720p_14B" : "sky_df_14B", } if imported_base_model_type in eqv_map: imported_base_model_type = eqv_map[imported_base_model_type] comp_map = { "vace_14B" : [ "vace_multitalk_14B"], "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"], - "i2v" : [ "fantasy", "multitalk", "i2v_720p","flf2v_720p" ], - "ltxv_13B_distilled": ["ltxv_13B"], + "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], "fantasy": ["multitalk"], - "sky_df_14B": ["sky_df_1.3B", "sky_df_720p_14B"], + "sky_df_14B": ["sky_df_1.3B"], "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], } comp_list= comp_map.get(imported_base_model_type, None) if comp_list == None: return False return curent_base_model_type in comp_list -def get_model_finetune_def(model_type): - return finetunes.get(model_type, None ) +def get_model_def(model_type): + return models_def.get(model_type, None ) @@ -1743,7 +1719,7 @@ def get_model_family(model_type): 
def test_class_i2v(model_type): model_type = get_base_model_type(model_type) - return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v", "multitalk" ] + return model_type in ["i2v", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v", "multitalk" ] def test_vace_module(model_type): model_type = get_base_model_type(model_type) @@ -1751,13 +1727,13 @@ def test_vace_module(model_type): def test_any_sliding_window(model_type): model_type = get_base_model_type(model_type) - return test_vace_module(model_type) or model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "multitalk", "t2v", "fantasy"] or test_class_i2v(model_type) + return test_vace_module(model_type) or model_type in ["sky_df_1.3B", "sky_df_14B", "ltxv_13B", "multitalk", "t2v", "fantasy"] or test_class_i2v(model_type) def get_model_min_frames_and_step(model_type): model_type = get_base_model_type(model_type) - if model_type in ["sky_df_14B", "sky_df_720p_14B"]: + if model_type in ["sky_df_14B"]: return 17, 20 - elif model_type in ["ltxv_13B", "ltxv_13B_distilled"]: + elif model_type in ["ltxv_13B"]: return 17, 8 elif test_vace_module(model_type): return 17, 4 @@ -1768,11 +1744,11 @@ def get_model_fps(model_type): model_type = get_base_model_type(model_type) if model_type in ["hunyuan_avatar", "hunyuan_custom_audio", "multitalk", "vace_multitalk_14B"]: fps = 25 - elif model_type in ["sky_df_14B", "sky_df_720p_14B", "hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: + elif model_type in ["sky_df_14B", "hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: fps = 24 elif model_type in ["fantasy"]: fps = 23 - elif model_type in ["ltxv_13B", "ltxv_13B_distilled"]: + elif model_type in ["ltxv_13B"]: fps = 30 else: fps = 16 @@ -1790,96 +1766,30 @@ def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): return fps def get_model_name(model_type, description_container = [""]): - finetune_def = get_model_finetune_def(model_type) - if finetune_def != None: - model_name = finetune_def["name"] - description = finetune_def["description"] - description_container[0] = description - return model_name - model_filename = get_model_filename(model_type) - if "Fun" in model_filename: - model_name = "Fun InP image2video" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model." - elif "Vace" in model_filename: - model_name = "Vace ControlNet" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video." - elif "image" in model_filename: - model_name = "Wan2.1 image2video" - model_name += " 720p" if "720p" in model_filename else " 480p" - model_name += " 14B" - if "720p" in model_filename: - description = "The standard Wan Image 2 Video specialized to generate 720p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)" - else: - description = "The standard Wan Image 2 Video specialized to generate 480p images. 
It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)" - elif "recam" in model_filename: - model_name = "ReCamMaster" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The Recam Master in theory should allow you to replay a video by applying a different camera movement. The model supports only video that are at least 81 frames long (any frame beyond will be ignored)" - elif "sky_reels2_diffusion_forcing" in model_filename: - model_name = "SkyReels2 Diffusion Forcing" - if "720p" in model_filename : - model_name += " 720p" - elif not "1.3B" in model_filename : - model_name += " 540p" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video." - elif "phantom" in model_filename: - model_name = "Phantom" - if "14B" in model_filename: - model_name += " 14B" - description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It seems to produce better results if you keep the original background of the Image Referendes." - else: - model_name += " 1.3B" - description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nice results when used at 720p." - elif "ltxv_0.9.7_13B_dev" in model_filename: - model_name = "LTX Video 0.9.7 13B" - description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer." - elif "ltxv_0.9.7_13B_distilled" in model_filename: - model_name = "LTX Video 0.9.7 Distilled 13B" - description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).This distilled version is a very fast version and retains a high level of quality. The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer." - elif "hunyuan_video_720" in model_filename: - model_name = "Hunyuan Video text2video 720p 13B" - description = "Probably the best text 2 video model available." - elif "hunyuan_video_i2v" in model_filename: - model_name = "Hunyuan Video image2video 720p 13B" - description = "A good looking image 2 video model, but not so good in prompt adherence." - elif "hunyuan_video_custom" in model_filename: - if "audio" in model_filename: - model_name = "Hunyuan Video Custom Audio 720p 13B" - description = "The Hunyuan Video Custom Audio model can be used to generate scenes of a person speaking given a Reference Image and a Recorded Voice or Song. The reference image is not a start image and therefore one can represent the person in a different context.The video length can be anything up to 10s. It is also quite good to generate no sound Video based on a person." - elif "edit" in model_filename: - model_name = "Hunyuan Video Custom Edit 720p 13B" - description = "The Hunyuan Video Custom Edit model can be used to do Video inpainting on a person (add accessories or completely replace the person). You will need in any case to define a Video Mask which will indicate which area of the Video should be edited." 
- else: - model_name = "Hunyuan Video Custom 720p 13B" - description = "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the momment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps." - elif "hunyuan_video_avatar" in model_filename: - model_name = "Hunyuan Video Avatar 720p 13B" - description = "With the Hunyuan Video Avatar model you can animate a person based on the content of an audio input. Please note that the video generator works by processing 128 frames segment at a time (even if you ask less). The good news is that it will concatenate multiple segments for long video generation (max 3 segments recommended as the quality will get worse)." - else: - model_name = "Wan2.1 text2video" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The original Wan Text 2 Video model. Most other models have been built on top of it" + model_def = get_model_def(model_type) + if model_def == None: raise Exception(f"Unknown model {model_type}") + model_name = model_def["name"] + description = model_def["description"] description_container[0] = description return model_name def get_model_record(model_name): return f"WanGP v{WanGP_version} by DeepBeepMeep - " + model_name -def get_finetune_URLs(model_type, stack= []): - finetune_def = finetunes.get(model_type, None) - if finetune_def != None: - URLs = finetune_def["URLs"] - if isinstance(URLs, str): - if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") - return get_finetune_URLs(URLs, stack = stack + [URLs] ) +def get_model_recursive_prop(model_type, prop = "URLs", return_list = False, stack= []): + model_def = models_def.get(model_type, None) + if model_def != None: + prop_value = model_def.get(prop, None) + if prop_value == None: + return [] + if isinstance(prop_value, str): + if len(stack) > 10: raise Exception(f"Circular Reference in Model {prop} dependencies: {stack}") + return get_model_recursive_prop(prop_value, prop = prop, stack = stack + [prop_value] ) else: - return URLs + return prop_value else: if model_type in model_types: - return model_type + return [] if return_list else model_type else: raise Exception(f"Unknown model type '{model_type}'") @@ -1889,17 +1799,14 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m choices = modules_files.get(model_type, None) if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") else: - finetune_def = finetunes.get(model_type, None) - if finetune_def != None: - URLs = finetune_def["URLs"] - if isinstance(URLs, str): - if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") - return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs]) - else: - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + model_def = models_def.get(model_type, None) + if model_def == None: raise Exception(f"Unknown model type {model_type}") + URLs = model_def["URLs"] + if isinstance(URLs, str): + if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") + return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs]) else: - signature = model_signatures[model_type] - choices = [ name for name in transformer_choices if signature in name] + choices = [ ("ckpts/" + 
os.path.basename(path) if path.startswith("http") else path) for path in URLs ] if len(quantization) == 0: quantization = "bf16" @@ -1954,8 +1861,8 @@ def fix_settings(model_type, ui_defaults): if image_prompt_type != None : if not isinstance(image_prompt_type, str): image_prompt_type = "S" if image_prompt_type == 0 else "SE" - if model_type == "flf2v_720p" and not "E" in image_prompt_type: - image_prompt_type = "SE" + # if model_type == "flf2v_720p" and not "E" in image_prompt_type: + # image_prompt_type = "SE" if video_settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type @@ -1968,7 +1875,7 @@ def fix_settings(model_type, ui_defaults): audio_prompt_type = ui_defaults.get("audio_prompt_type", None) if video_settings_version < 2.2: - if not model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled"]: + if not model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "ltxv_13B"]: for p in ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]: if p in ui_defaults: del ui_defaults[p] @@ -1979,7 +1886,7 @@ def fix_settings(model_type, ui_defaults): video_prompt_type = ui_defaults.get("video_prompt_type", "") - if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"]: + if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B", "flux_dev_kontext"]: if not "I" in video_prompt_type: # workaround for settings corruption video_prompt_type += "I" if model_type in ["hunyuan"]: @@ -2022,9 +1929,9 @@ def get_default_settings(model_type): i2v = test_class_i2v(model_type) defaults_filename = get_settings_file_name(model_type) if not Path(defaults_filename).is_file(): - finetune_def = get_model_finetune_def(model_type) - if finetune_def != None: - ui_defaults = finetune_def["settings"] + model_def = get_model_def(model_type) + if model_def != None: + ui_defaults = model_def["settings"] if len(ui_defaults.get("prompt","")) == 0: ui_defaults["prompt"]= get_default_prompt(i2v) else: @@ -2067,7 +1974,11 @@ def get_default_settings(model_type): "guidance_scale": 7.0, }) - elif model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]: + elif model_type in ["flux_dev_kontext"]: + ui_defaults.update({ + "video_prompt_type": "I", + }) + elif model_type in ["sky_df_1.3B", "sky_df_14B"]: ui_defaults.update({ "guidance_scale": 6.0, "flow_shift": 8, @@ -2142,33 +2053,50 @@ def get_default_settings(model_type): return ui_defaults -finetunes_paths = glob.glob( os.path.join("finetunes", "*.json") ) -finetunes_paths.sort() -for file_path in finetunes_paths: - finetune_id = os.path.basename(file_path)[:-5] +def set_default_model_def(model_def, model_type): + if model_type == "flux_dev_kontext": + model_def.update({"image_outputs": True}) + + +models_def_paths = glob.glob( os.path.join("defaults", "*.json") ) + glob.glob( os.path.join("finetunes", "*.json") ) +models_def_paths.sort() +for file_path in models_def_paths: + model_type = os.path.basename(file_path)[:-5] with open(file_path, "r", encoding="utf-8") as f: try: json_def = json.load(f) except Exception as e: - raise Exception(f"Error while parsing Finetune Definition File '{file_path}': {str(e)}") - finetune_def = json_def["model"] - del json_def["model"] - finetune_def["settings"] = json_def - 
finetunes[finetune_id] = finetune_def + model_def = json_def["model"] + model_def["path"] = file_path + del json_def["model"] + settings = json_def + existing_model_def = models_def.get(model_type, None) + if existing_model_def is not None: + existing_settings = existing_model_def["settings"] + existing_settings.update(settings) + existing_model_def.update(model_def) + else: + models_def[model_type] = model_def + set_default_model_def(model_def, model_type) + model_def["settings"] = settings -model_types += [model_type for model_type, finetune in finetunes.items() if finetune.get("visible", True)] -displayed_model_types= model_types +model_types = models_def.keys() +displayed_model_types= [] +for model_type in model_types: + model_def = get_model_def(model_type) + if not (model_def != None and model_def.get("visible", True) == False): + displayed_model_types.append(model_type) -# model_types += [model_type for model_type in abstract_model_types if model_type not in finetunes] transformer_types = server_config.get("transformer_types", []) transformer_type = server_config.get("last_model_type", None) advanced = server_config.get("last_advanced_choice", False) if args.advanced: advanced = True -if transformer_type != None and not transformer_type in model_types and not transformer_type in finetunes: transformer_type = None +if transformer_type != None and not transformer_type in model_types and not transformer_type in models_def: transformer_type = None if transformer_type == None: - transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0] + transformer_type = transformer_types[0] if len(transformer_types) > 0 else "t2v" transformer_quantization =server_config.get("transformer_quantization", "int8") @@ -2222,18 +2150,11 @@ if args.compile: #args.fastest or compile="transformer" lock_ui_compile = True -#attention_mode="sage" -#attention_mode="sage2" -#attention_mode="flash" -#attention_mode="sdpa" -#attention_mode="xformers" -# compile = "transformer" - def save_quantized_model(model, model_type, model_filename, dtype, config_file): if "quanto" in model_filename: return - finetune_def = get_model_finetune_def(model_type) - if finetune_def == None: return - URLs= finetune_def["URLs"] + model_def = get_model_def(model_type) + if model_def == None: return + URLs= model_def["URLs"] if isinstance(URLs, str): print("Unable to create a quantized model for a finetune that references external files") return @@ -2258,16 +2179,13 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file) print(f"New quantized file '{model_filename}' has been created for finetune Id '{model_type}'.") if not model_filename in URLs: URLs.append(model_filename) - finetune_def = finetune_def.copy() - if "settings" in finetune_def: - saved_def = typing.OrderedDict() - saved_def["model"] = finetune_def - saved_def.update(finetune_def["settings"]) - del finetune_def["settings"] - finetune_file = os.path.join("finetunes" , model_type + ".json") - with open(finetune_file, "w", encoding="utf-8") as writer: - writer.write(json.dumps(saved_def, indent=4)) - print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") + finetune_file = os.path.join(os.path.dirname(model_def["path"]) , model_type + ".json") + with open(finetune_file, 'r', encoding='utf-8') as reader: + saved_finetune_def = json.load(reader) +
saved_finetune_def["model"]["URLs"] = URLs + with open(finetune_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(saved_finetune_def, indent=4)) + print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") def get_loras_preprocessor(transformer, model_type): preprocessor = getattr(transformer, "preprocess_loras", None) @@ -2279,15 +2197,6 @@ def get_loras_preprocessor(transformer, model_type): return preprocessor_wrapper - -# def get_model_manager(model_family): -# if model_family == "wan": -# return None -# elif model_family == "ltxv": -# from ltxv import model_def -# return model_def -# else: -# raise Exception("model family not supported") def get_wan_text_encoder_filename(text_encoder_quantization): text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" @@ -2389,10 +2298,10 @@ def download_models(model_filename, model_type): urlretrieve(url,filename, create_progress_hook(filename)) model_family = get_model_family(model_type) - finetune_def = get_model_finetune_def(model_type) - if finetune_def != None and not model_type in modules_files: + model_def = get_model_def(model_type) + if model_def != None and not model_type in modules_files: if not os.path.isfile(model_filename ): - URLs = get_finetune_URLs(model_type) + URLs = get_model_recursive_prop(model_type, "URLs") if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after use_url = model_filename for url in URLs: @@ -2400,7 +2309,7 @@ def download_models(model_filename, model_type): use_url = url break if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") try: download_file(use_url, model_filename) except Exception as e: @@ -2408,11 +2317,14 @@ def download_models(model_filename, model_type): raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") model_filename = None - for url in finetune_def.get("preload_URLs", []): + preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) + model_loras = get_model_recursive_prop(model_type, "loras", return_list= True) + + for url in preload_URLs + model_loras: filename = "ckpts/" + url.split("/")[-1] if not os.path.isfile(filename ): if not url.startswith("http"): - raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") + raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. 
Please add an URL in the model definition file.") try: download_file(url, filename) except Exception as e: @@ -2420,21 +2332,21 @@ def download_models(model_filename, model_type): raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'") if model_family == "wan": text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) - model_def = { + model_files = { "repoId" : "DeepBeepMeep/Wan2.1", "sourceFolderList" : ["xlm-roberta-large", "umt5-xxl", "" ], "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] } elif model_family == "ltxv": text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) - model_def = { + model_files = { "repoId" : "DeepBeepMeep/LTX_Video", "sourceFolderList" : ["T5_xxl_1.1", "" ], "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ] } elif model_family == "hunyuan": text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) - model_def = { + model_files = { "repoId" : "DeepBeepMeep/HunyuanVideo", "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , @@ -2444,12 +2356,31 @@ def download_models(model_filename, model_type): [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename) ] } + elif model_family == "flux": + text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization) + model_files = [ + { + "repoId" : "DeepBeepMeep/Flux", + "sourceFolderList" : [""], + "fileList" : [ ["flux_vae.safetensors"] ] + }, + { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : ["T5_xxl_1.1"], + "fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ] + }, + { + "repoId" : "DeepBeepMeep/HunyuanVideo", + "sourceFolderList" : [ "clip_vit_large_patch14", ], + "fileList" :[ + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ] + } + ] - # else: - # model_manager = get_model_manager(model_family) - # model_def = model_manager.get_files_def(model_filename, text_encoder_quantization) - - process_files_def(**model_def) + if not isinstance(model_files, list): model_files = [model_files] + for one_repo in model_files: + process_files_def(**one_repo) offload.default_verboseLevel = verbose_level @@ -2543,13 +2474,13 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, 
default_lora_preset_prompt, default_lora_preset -def load_wan_model(model_filename, model_type, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): +def load_wan_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): if test_class_i2v(base_model_type): cfg = WAN_CONFIGS['i2v-14B'] else: cfg = WAN_CONFIGS['t2v-14B'] # cfg = WAN_CONFIGS['t2v-1.3B'] - if base_model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + if base_model_type in ("sky_df_1.3B", "sky_df_14B"): model_factory = wan.DTT2V else: model_factory = wan.WanAny2V @@ -2573,12 +2504,13 @@ def load_wan_model(model_filename, model_type, base_model_type, quantizeTransfor pipe["text_encoder_2"] = wan_model.clip.model return wan_model, pipe -def load_ltxv_model(model_filename, model_type, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): +def load_ltxv_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): from ltx_video.ltxv import LTXV ltxv_model = LTXV( model_filepath = model_filename, text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization), + model_def = model_def, dtype = dtype, # quantizeTransformer = quantizeTransformer, VAE_dtype = VAE_dtype, @@ -2590,7 +2522,28 @@ def load_ltxv_model(model_filename, model_type, base_model_type, quantizeTransfo return ltxv_model, pipe -def load_hunyuan_model(model_filename, model_type = None, base_model_type = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + +def load_flux_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + from flux.flux_main import model_factory + + flux_model = model_factory( + checkpoint_dir="ckpts", + model_filename=model_filename, + model_type = model_type, + base_model_type=base_model_type, + text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization), + quantizeTransformer = quantizeTransformer, + dtype = dtype, + VAE_dtype = VAE_dtype, + mixed_precision_transformer = mixed_precision_transformer, + save_quantized = save_quantized + ) + + pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5} + + return flux_model, pipe + +def load_hunyuan_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): from hyvideo.hunyuan import HunyuanVideoSampler hunyuan_model = HunyuanVideoSampler.from_pretrained( @@ -2634,21 +2587,20 @@ def get_transformer_model(model): def load_models(model_type): - global transformer_type, transformer_loras_filenames + global transformer_type base_model_type = get_base_model_type(model_type) - finetune_def = get_model_finetune_def(model_type) + model_def = get_model_def(model_type) preload 
=int(args.preload) - save_quantized = args.save_quantized and finetune_def != None + save_quantized = args.save_quantized and model_def != None model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) - modules = finetune_def.get("modules", []) if finetune_def != None else [] + modules = get_model_recursive_prop(model_type, "modules", return_list= True) if save_quantized and "quanto" in model_filename: save_quantized = False print("You need to provide a non-quantized model in order to create and save a quantized model") if save_quantized and len(modules) > 0: - _, model_types_no_module = dependent_models_types = get_dependent_models(base_model_type, transformer_quantization, transformer_dtype_policy) print("Unable to create a quantized finetune model because some modules are declared in the finetune definition. If your finetune already includes the module weights, you can remove the 'modules' entry and try again. Otherwise, temporarily change the model 'architecture' to one that does not require these modules, quantize, then restore the original 'modules' and 'architecture' entries.") save_quantized = False - quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename + quantizeTransformer = not save_quantized and model_def !=None and transformer_quantization in ("int8", "fp8") and model_def.get("auto_quantize", False) and not "quanto" in model_filename if quantizeTransformer and len(modules) > 0: print("Autoquantize is not yet supported when some modules are declared") quantizeTransformer = False @@ -2660,22 +2612,16 @@ def load_models(model_type): perc_reserved_mem_max = args.perc_reserved_mem_max if preload == 0: preload = server_config.get("preload_in_VRAM", 0) - new_transformer_loras_filenames = None - dependent_models, dependent_models_types = get_dependent_models(model_type, quantization= transformer_quantization, dtype_policy = transformer_dtype) - new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None - - model_file_list = dependent_models + [model_filename] - model_type_list = dependent_models_types + [model_type] + model_file_list = [model_filename] + model_type_list = [model_type] new_transformer_filename = model_file_list[-1] for module_type in modules: model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True)) - model_type_list.append(module_type) for filename, file_model_type in zip(model_file_list, model_type_list): download_models(filename, file_model_type) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" - transformer_loras_filenames = None transformer_type = None for i, filename in enumerate(model_file_list): if i==0: @@ -2684,11 +2630,13 @@ def load_models(model_type): print(f"Loading Module '{filename}' ...") if model_family == "wan" : - wan_model, pipe = load_wan_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe =
load_wan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "ltxv": - wan_model, pipe = load_ltxv_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_ltxv_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + elif model_family == "flux": + wan_model, pipe = load_flux_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "hunyuan": - wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) else: raise Exception(f"Model '{new_transformer_filename}' not supported.") wan_model._model_file_name = new_transformer_filename @@ -2720,7 +2668,6 @@ def load_models(model_type): offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) - transformer_loras_filenames = new_transformer_loras_filenames transformer_type = model_type return wan_model, offloadobj, pipe["transformer"] @@ -3054,7 +3001,7 @@ def finalize_generation(state): return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="") def get_default_video_info(): - return "Please Select a Video" + return "Please Select a Video / Image" def get_file_list(state, input_file_list): @@ -3093,14 +3040,23 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(file_list) > 0: configs = file_settings_list[choice] file_name = file_list[choice] - fps, width, height, frames_count = get_video_info(file_name) - nb_audio_tracks = extract_audio_tracks(file_name,query_only = True) values = [ os.path.basename(file_name)] labels = [ "File Name"] misc_values= [] misc_labels = [] pp_values= [] pp_labels = [] + extension = os.path.splitext(file_name)[-1] + if not extension in [".mp4"]: + img = Image.open(file_name) + width, height = img.size + configs = None + is_image = True + nb_audio_tracks = 0 + else: + fps, width, height, frames_count = get_video_info(file_name) + is_image = False + nb_audio_tracks = extract_audio_tracks(file_name,query_only = True) if configs != None:
video_model_name = configs.get("type", "Unknown model") if "-" in video_model_name: video_model_name = video_model_name[video_model_name.find("-")+2:] @@ -3128,8 +3084,12 @@ def select_video(state, input_file_list, event_data: gr.EventData): labels += misc_labels video_creation_date = str(get_file_creation_date(file_name)) if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] - values += [f"{width}x{height}", f"{frames_count} frames (duration={frames_count/fps:.1f} s, fps={round(fps)})"] - labels += ["Resolution", "Frames"] + if is_image: + values += [f"{width}x{height}"] + labels += ["Resolution"] + else: + values += [f"{width}x{height}", f"{frames_count} frames (duration={frames_count/fps:.1f} s, fps={round(fps)})"] + labels += ["Resolution", "Frames"] if nb_audio_tracks > 0: values +=[nb_audio_tracks] labels +=["Nb Audio Tracks"] @@ -3246,8 +3206,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): else: html = get_default_video_info() visible= len(file_list) > 0 - return choice, html, gr.update(visible=visible), gr.update(visible=visible) - + return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and is_image) def expand_slist(slist, num_inference_steps ): new_slist= [] inc = len(slist) / num_inference_steps @@ -3834,6 +3793,14 @@ def edit_video( cleanup_temp_audio_files(audio_tracks) clear_status(state) +def get_transformer_loras(model_type): + model_def = get_model_def(model_type) + transformer_loras_filenames = get_model_recursive_prop(model_type, "loras", return_list=True) + transformer_loras_filenames = [ "ckpts/" + os.path.basename(filename) for filename in transformer_loras_filenames] + transformer_loras_multipliers = get_model_recursive_prop(model_type, "loras_multipliers", return_list=True) + [1.] * len(transformer_loras_filenames) + transformer_loras_multipliers = transformer_loras_multipliers[:len(transformer_loras_filenames)] + return transformer_loras_filenames, transformer_loras_multipliers + def generate_video( task, send_cmd, @@ -3914,7 +3881,9 @@ def generate_video( file_list = gen["file_list"] file_settings_list = gen["file_settings_list"] - prompt_no = gen["prompt_no"] + + model_def = get_model_def(model_type) + is_image = model_def.get("image_outputs", False) fit_canvas = server_config.get("fit_canvas", 0) @@ -3967,7 +3936,8 @@ def generate_video( loras = state["loras"] loras_slists = [] - if len(loras) > 0 or transformer_loras_filenames != None: + transformer_loras_filenames, transformer_loras_multipliers = get_transformer_loras(model_type) + if len(loras) > 0 or len(transformer_loras_filenames) > 0 : def is_float(element: any) -> bool: if element is None: return False @@ -4005,11 +3975,12 @@ def generate_video( lora_dir = get_lora_dir(model_type) loras_selected = [ os.path.join(lora_dir, lora) for lora in activated_loras] - pinnedLora = profile !=5 and transformer_loras_filenames == None #False # # # + pinnedLora = profile !=5 # and transformer_loras_filenames == None False # # # split_linear_modules_map = getattr(trans,"split_linear_modules_map", None) if transformer_loras_filenames != None: - loras_selected += transformer_loras_filenames - loras_list_mult_choices_nums.append(1.) 
+ loras_selected = transformer_loras_filenames + loras_selected + loras_list_mult_choices_nums = transformer_loras_multipliers + loras_list_mult_choices_nums + loras_slists = transformer_loras_multipliers + loras_slists offload.load_loras_into_model(trans, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map) errors = trans._loras_errors if len(errors) > 0: @@ -4039,6 +4010,7 @@ def generate_video( hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"] + flux_dev_kontext = base_model_type in ["flux_dev_kontext"] if "L" in image_prompt_type: if len(file_list)>0: @@ -4103,8 +4075,7 @@ def generate_video( if trans.enable_cache == "mag": trans.magcache_thresh = 0 trans.magcache_K = 2 - finetune_def = get_model_finetune_def(model_type) - def_mag_ratios = finetune_def.get("magcache_ratios", None) if finetune_def != None else None + def_mag_ratios = model_def.get("magcache_ratios", None) if model_def != None else None if def_mag_ratios != None: trans.def_mag_ratios = def_mag_ratios elif get_model_family(model_type) == "wan": @@ -4120,8 +4091,8 @@ def generate_video( elif trans.enable_cache == "tea": trans.rel_l1_thresh = 0 - finetune_def = get_model_finetune_def(model_type) - def_tea_coefficients = finetune_def.get("teacache_coefficients", None) if finetune_def != None else None + model_def = get_model_def(model_type) + def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None if def_tea_coefficients != None: trans.coefficients = def_tea_coefficients elif get_model_family(model_type) == "wan": @@ -4272,7 +4243,7 @@ def generate_video( abort = gen.get("abort", False) while not abort: - enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)) or RIFLEx_setting == 1 + enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1 if sliding_window: prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] new_extra_windows = gen.get("extra_windows",0) @@ -4308,7 +4279,7 @@ def generate_video( audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= window_start_frame, clip_length = current_video_length) if i2v and window_no > 1: src_video = pre_video_guide - if hunyuan_custom or hunyuan_avatar: + if hunyuan_custom or hunyuan_avatar or flux_dev_kontext: src_ref_images = image_refs elif phantom: src_ref_images = image_refs.copy() if image_refs != None else None @@ -4455,7 +4426,7 @@ def generate_video( input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video, denoising_strength=denoising_strength, target_camera= target_camera, - frame_num=(current_video_length // latent_size)* latent_size + 1, + frame_num=current_video_length if is_image else (current_video_length // latent_size)* latent_size + 1, height = height, width = width, fit_into_canvas = fit_canvas == 1, @@ -4608,14 +4579,27 @@ def generate_video( time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") save_prompt = original_prompts[0] + from wan.utils.utils import truncate_for_filesystem + extension = "jpg" if is_image else "mp4" + if os.name == 'nt': - file_name = 
f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt,50)).strip()}.mp4" + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt,50)).strip()}.{extension}" else: - file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt,100)).strip()}.mp4" + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt,100)).strip()}.{extension}" video_path = os.path.join(save_path, file_name) any_mmaudio = MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and sample.shape[1] >=fps - if len(control_audio_tracks) > 0 or source_audio != None or any_mmaudio or merged_audio_data is not None: + + if is_image: + sample = sample.permute(1,2,3,0) #c f h w -> f h w c + new_video_path = [] + for no, img in enumerate(sample): + img = Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy()) + img_path = os.path.splitext(video_path)[0] + ("" if no==0 else f"_{no}") + ".jpg" + new_video_path.append(img_path) + img.save(img_path) + video_path= new_video_path + elif len(control_audio_tracks) > 0 or source_audio != None or any_mmaudio or merged_audio_data is not None: save_path_tmp = video_path[:-4] + "_tmp.mp4" cache_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) if len(control_audio_tracks) > 0: @@ -4653,21 +4637,26 @@ def generate_video( if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0: configs["enhanced_prompt"] = "\n".join(prompts) configs["generation_time"] = round(end_time-start_time) + if is_image: configs["is_image"] = True metadata_choice = server_config.get("metadata_type","metadata") - if metadata_choice == "json": - with open(video_path.replace('.mp4', '.json'), 'w') as f: - json.dump(configs, f, indent=4) - elif metadata_choice == "metadata": - from mutagen.mp4 import MP4 - file = MP4(video_path) - file.tags['©cmt'] = [json.dumps(configs)] - file.save() - - print(f"New video saved to Path: "+video_path) - with lock: - file_list.append(video_path) - file_settings_list.append(configs) - + video_path = [video_path] if not isinstance(video_path, list) else video_path + for no, path in enumerate(video_path): + if metadata_choice == "json": + with open(path.replace(f'.{extension}', '.json'), 'w') as f: + json.dump(configs, f, indent=4) + elif metadata_choice == "metadata" and not is_image: + from mutagen.mp4 import MP4 + file = MP4(path) + file.tags['©cmt'] = [json.dumps(configs)] + file.save() + if is_image: + print(f"New image saved to Path: "+ path) + else: + print(f"New video saved to Path: "+ path) + with lock: + file_list.append(path) + file_settings_list.append(configs if no > 0 else configs.copy()) + # Play notification sound for single video try: if server_config.get("notification_sound_enabled", 1): @@ -4698,7 +4687,7 @@ def prepare_generate_video(state): def generate_preview(latents): import einops - + # thanks Comfyui for the rgb factors model_family = get_model_family(transformer_type) if model_family == "wan": latent_channels = 16 @@ -4727,6 +4716,29 @@ def generate_preview(latents): latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] # latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + elif model_family =="flux": + scale_factor = 0.3611 + shift_factor = 0.1159 + latent_rgb_factors =[ + [-0.0346, 0.0244, 0.0681], + [ 0.0034, 0.0210, 0.0687], + [ 0.0275, -0.0668, -0.0433], + [-0.0174, 0.0160, 0.0617], + [ 0.0859, 0.0721, 
0.0329], + [ 0.0004, 0.0383, 0.0115], + [ 0.0405, 0.0861, 0.0915], + [-0.0236, -0.0185, -0.0259], + [-0.0245, 0.0250, 0.1180], + [ 0.1008, 0.0755, -0.0421], + [-0.0515, 0.0201, 0.0011], + [ 0.0428, -0.0012, -0.0036], + [ 0.0817, 0.0765, 0.0749], + [-0.1264, -0.0522, -0.1103], + [-0.0280, -0.0881, -0.0499], + [-0.1262, -0.0982, -0.0778] + ] + latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + elif model_family == "ltxv": latent_channels = 128 latent_dimensions = 3 @@ -5527,7 +5539,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None base_model_type = get_base_model_type(model_type) if model_type != base_model_type: inputs["base_model_type"] = base_model_type - diffusion_forcing = base_model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"] + diffusion_forcing = base_model_type in ["sky_df_1.3B", "sky_df_14B"] vace = test_vace_module(base_model_type) if target == "settings": @@ -5608,21 +5620,34 @@ def init_generate(state, input_file_list, last_choice): def video_to_control_video(state, input_file_list, choice): file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() - gr.Info("Select Video was copied to Control Video input") + gr.Info("Selected Video was copied to Control Video input") return file_list[choice] def video_to_source_video(state, input_file_list, choice): file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() - gr.Info("Select Video was copied to Source Video input") + gr.Info("Selected Video was copied to Source Video input") return file_list[choice] - + +def image_to_ref_image(state, input_file_list, choice, target, target_name): + file_list, file_settings_list = get_file_list(state, input_file_list) + if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() + gr.Info(f"Selected Image was copied to {target_name}") + if target == None: + target =[] + target.append( file_list[choice]) + return target + + def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation): gen = get_gen_info(state) file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) : return gr.update(), gr.update() + if not file_list[choice].endswith(".mp4"): + gr.Info("Post processing is only available with Videos") + return gr.update(), gr.update() overrides = { "temporal_upsampling":PP_temporal_upsampling, "spatial_upsampling":PP_spatial_upsampling, @@ -5639,6 +5664,7 @@ def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling in_progress = gen.get("in_progress", False) return "edit", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update() + def eject_video_from_gallery(state, input_file_list, choice): gen = get_gen_info(state) file_list, file_settings_list = get_file_list(state, input_file_list) @@ -5754,7 +5780,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw configs = json.load(f) except: pass - else: + elif file_path.endswith(".mp4"): from mutagen.mp4 import MP4 try: file = MP4(file_path) @@ -5778,7 +5804,7 @@ def get_settings_from_file(state, 
file_path, allow_json, merge_with_defaults, sw model_type = get_model_type(model_filename) if model_type == None: model_type = current_model_type - elif not model_type in model_types and not model_type in finetune_def: + elif not model_type in model_types: model_type = current_model_type fix_settings(model_type, configs) if switch_type_if_compatible and are_model_types_compatible(model_type,current_model_type): @@ -6248,6 +6274,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gen = dict() gen["queue"] = [] state_dict["gen"] = gen + model_def = get_model_def(model_type) + if model_def == None: model_def = {} base_model_type = get_base_model_type(model_type) model_filename = get_model_filename( base_model_type ) preset_to_load = lora_preselected_preset if lora_preset_model == model_type else "" @@ -6333,7 +6361,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non flf2v = base_model_type == "flf2v_720p" diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename - ltxv_distilled = "ltxv" in model_filename and "distilled" in model_filename + lock_inference_steps = model_def.get("lock_inference_steps", False) recammaster = "recam" in model_filename vace = test_vace_module(base_model_type) phantom = "phantom" in model_filename @@ -6345,11 +6373,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non hunyuan_video_custom_audio = hunyuan_video_custom and "audio" in model_filename hunyuan_video_custom_edit = hunyuan_video_custom and "edit" in model_filename hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename + flux_dev_kontext = base_model_type in ["flux_dev_kontext"] + image_outputs = model_def.get("image_outputs", False) sliding_window_enabled = test_any_sliding_window(model_type) multi_prompts_gen_type_value = ui_defaults.get("multi_prompts_gen_type_value",0) prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type_value) any_video_source = True fps = get_model_fps(base_model_type) + image_prompt_type_value = "" + video_prompt_type_value = "" + any_start_image = False + any_end_image = False + any_reference_image = False + with gr.Column(visible= test_class_i2v(model_type) or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column: if vace: image_prompt_type_value= ui_defaults.get("image_prompt_type","") @@ -6417,11 +6453,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non keep_frames_video_source = gr.Text(visible=False) else: if test_class_i2v(model_type): - image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) - image_prompt_type_choices = [] if flf2v else [("Use only a Start Image", "S")] + # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) + image_prompt_type_value= ui_defaults.get("image_prompt_type","S" ) + image_prompt_type_choices = [("Use only a Start Image", "S")] image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - + any_start_image = True + any_end_image = True image_start = gr.Gallery(preview= True, label="Images as starting points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", 
None), visible= "S" in image_prompt_type_value) @@ -6438,7 +6476,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non keep_frames_video_source = gr.Text(visible=False) any_video_source = False - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v) as video_prompt_column: + with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or flux_dev_kontext) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) any_control_video = True @@ -6489,7 +6527,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non # video_prompt_video_guide_trigger = gr.Text(visible=False, value="") if t2v: - video_prompt_type_video_mask = gr.Dropdown(value = "", visible = False) + video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False) elif hunyuan_video_custom_edit: video_prompt_type_video_mask = gr.Dropdown( choices=[ @@ -6520,7 +6558,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Area Processed", scale = 2 ) if t2v: - video_prompt_type_image_refs = gr.Dropdown(value="", visible =False) + video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False) elif vace: video_prompt_type_image_refs = gr.Dropdown( choices=[ @@ -6561,7 +6599,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - + any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images", type ="pil", show_label= True, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, @@ -6576,7 +6614,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non # ("Keep it for first Image (landscape) and remove it for other Images (objects / people)", 2), ], value=ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar + label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux_dev_kontext ) any_audio_voices_support = any_audio_track(base_model_type) @@ -6655,8 +6693,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label= label ) with gr.Row(): - if recammaster: - video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False) + if image_outputs: + video_length = gr.Slider(1, 16, value=ui_defaults.get("video_length", 1), 
step=1, label="Number of Images to Generate", visible = flux_dev_kontext) + elif recammaster: + video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: min_frames, frames_step = get_model_min_frames_and_step(base_model_type) @@ -6664,8 +6704,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non "video_length", 81 if get_model_family(base_model_type)=="wan" else 97), step=frames_step, label=f"Number of frames ({fps} = 1s)", interactive= True) - with gr.Row(visible = not ltxv_distilled) as inference_steps_row: - num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps") + with gr.Row(visible = not lock_inference_steps) as inference_steps_row: + num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True) @@ -6679,7 +6719,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v)) audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=fantasy or multitalk) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v)) - flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale") + flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) with gr.Row(visible = get_model_family(model_type) == "wan" and not diffusion_forcing ) as sample_solver_row: sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver",""), choices=[ @@ -6693,7 +6733,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = vace): control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) - negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", "") ) + negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = vace or t2v or test_class_i2v(model_type) ) with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col: gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") with gr.Row(): @@ -6701,7 +6741,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non NAG_tau = gr.Slider(1.0, 5.0, value=ui_defaults.get("NAG_tau",3.5), step=0.1, label="NAG Tau", visible = True) NAG_alpha = gr.Slider(1.0, 2.0, value=ui_defaults.get("NAG_alpha",.5), step=0.1, label="NAG Alpha", visible = True) with gr.Row(): - repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Num. 
of Generated Videos per Prompt") + repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Num. of Generated Videos per Prompt", visible = not image_outputs) multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0), choices=[ ("Generate every combination of images and texts", 0), @@ -6720,7 +6760,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Activated Loras" ) loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str) - with gr.Tab("Steps Skipping", visible = not ltxv) as speed_tab: + with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs)) as speed_tab: with gr.Column(): gr.Markdown("Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.") gr.Markdown("Steps Skipping consumes also VRAM. It is recommended not to skip at least the first 10% steps.") @@ -6763,7 +6803,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Rife x4 frames/s", "rife4"), ], value=temporal_upsampling, - visible=True, + visible=not image_outputs, scale = 1, label="Temporal Upsampling", elem_classes= element_class @@ -6785,7 +6825,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non return temporal_upsampling, spatial_upsampling temporal_upsampling, spatial_upsampling = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", "")) - with gr.Tab("MMAudio", visible = server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type)) as mmaudio_tab: + with gr.Tab("MMAudio", visible = server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type) and not image_outputs) as mmaudio_tab: with gr.Column(): gr.Markdown("Add a soundtrack based on the content of the Generated Video") def gen_mmaudio_dropdowns(MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, MMAudio_seed = None, element_class = None, max_height = None ): @@ -6812,7 +6852,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns(ui_defaults.get("MMAudio_setting", 0), ui_defaults.get("MMAudio_prompt", ""), ui_defaults.get("MMAudio_neg_prompt", "")) - with gr.Tab("Quality", visible = not ltxv) as quality_tab: + with gr.Tab("Quality", visible = not (ltxv or flux_dev_kontext)) as quality_tab: with gr.Column(visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_custom or hunyuan_video_avatar) ) as skip_layer_guidance_row: gr.Markdown("Skip Layer Guidance (improves video quality, requires guidance > 1)") with gr.Row(): @@ -6903,7 +6943,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Text Prompts separated by a Carriage Return" ) - with gr.Tab("Misc.") as misc_tab: + with gr.Tab("Misc.", visible = not image_outputs) as misc_tab: with gr.Column(visible = not (recammaster or ltxv or diffusion_forcing)) as RIFLEx_setting_col: gr.Markdown("With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model") RIFLEx_setting = gr.Dropdown( @@ -6974,13 +7014,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Accordion("Video Info and 
Late Post Processing", open=False) as video_info_accordion: with gr.Tabs() as video_info_tabs: with gr.Tab("Information", id="video_info"): + default_visibility = {} if update_form else {"visible" : False} video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) - with gr.Row(visible= False) as video_buttons_row: + with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", size ="sm") video_info_to_control_video_btn = gr.Button("To Control Video", size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", size ="sm", visible = any_video_source) video_info_eject_video_btn = gr.Button("Eject Video", size ="sm") - with gr.Tab("Post Processing", id= "post_processing") as video_postprocessing_tab: + with gr.Row(**default_visibility) as image_buttons_row: + video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", visible = any_start_image ) + video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", visible = any_end_image) + video_info_to_reference_image_btn = gr.Button("To Reference Image", size ="sm", visible = any_reference_image) + video_info_eject_image_btn = gr.Button("Eject Image", size ="sm") + with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): with gr.Column(): PP_temporal_upsampling, PP_spatial_upsampling = gen_upsampling_dropdowns("", "", element_class ="postprocess") @@ -6990,10 +7036,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate") video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True) - with gr.Tab("Add Videos", id= "video_add"): + with gr.Tab("Add Videos / Images", id= "video_add"): files_to_load = gr.Files(label= "Files to Load in Gallery", height=120) with gr.Row(): - video_info_add_videos_btn = gr.Button("Add Videos", size ="sm") + video_info_add_videos_btn = gr.Button("Add Videos / Images", size ="sm") if not update_form: @@ -7046,6 +7092,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, + video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, NAG_col] # presets_column, if update_form: locals_dict = locals() @@ -7072,7 +7119,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, 
prompt_column_wizard_vars, *prompt_vars]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) - output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, video_postprocessing_tab] ) + output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab] ) preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) def refresh_status_async(state, progress=gr.Progress()): @@ -7128,9 +7175,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_choice, refresh_form_trigger]) video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] ) - video_info_eject_video_btn.click(fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] ) + gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] ) video_info_to_control_video_btn.click(fn=video_to_control_video, inputs =[state, output, last_choice], outputs = [video_guide] ) video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] ) + video_info_to_start_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] ) + video_info_to_end_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) + video_info_to_reference_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) @@ -7681,7 +7731,7 @@ def generate_info_tab(): def get_sorted_dropdown(dropdown_types): - families_order = {"wan":0, "ltxv":1, "hunyuan":2 } + families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3 } dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types] dropdown_names = [ get_model_name(type) for type in dropdown_types] From 37f41804a68b71a17097f8046df4fbf88883657b Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 13 Jul 2025 04:29:46 +0200 Subject: [PATCH 2/9] loras flux --- loras_flux/readme.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 loras_flux/readme.txt diff --git a/loras_flux/readme.txt b/loras_flux/readme.txt new file mode 100644 index 0000000..2424800 --- /dev/null +++ b/loras_flux/readme.txt @@ -0,0 +1 @@ +flux loras go 
here \ No newline at end of file From 64c59c15d9afdc0e8f339ce4140a7b5d773bab5f Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Tue, 15 Jul 2025 22:26:56 +0200 Subject: [PATCH 3/9] Flux Kontext and more --- README.md | 20 + defaults/flux_dev_kontext.json | 9 +- defaults/t2i.json | 13 + defaults/vace_14B_fusionix.json | 2 +- defaults/vace_14B_fusionix_t2i.json | 16 + docs/FINETUNES.md | 71 ++-- docs/LORAS.md | 66 ++-- docs/MODELS.md | 26 +- finetunes/put your finetunes here.txt | 0 flux/flux_main.py | 4 +- postprocessing/film_grain.py | 21 + preprocessing/matanyone/app.py | 85 ++++- wan/any2video.py | 106 +++--- wan/diffusion_forcing.py | 50 ++- wan/modules/attention.py | 6 +- wan/modules/model.py | 38 +- wan/multitalk/attention.py | 4 +- wan/multitalk/multitalk.py | 34 +- wan/multitalk/multitalk_utils.py | 2 +- wan/utils/utils.py | 26 +- wgp.py | 527 ++++++++++++++++---------- 21 files changed, 734 insertions(+), 392 deletions(-) create mode 100644 defaults/t2i.json create mode 100644 defaults/vace_14B_fusionix_t2i.json create mode 100644 finetunes/put your finetunes here.txt create mode 100644 postprocessing/film_grain.py diff --git a/README.md b/README.md index 85c68ac..004b7c3 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,26 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates +### July 15 2025: WanGP v7.0 is an AI Powered Photoshop +This release turns the Wan models into Image Generators. This goes far beyond generating a video made of a single frame: +- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB of VRAM +- With *image2image* the original WanGP text2video model becomes an image upsampler / restorer +- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ... +- You can in one click reuse a newly generated Image as a Start Image or Reference Image for a Video generation + +And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP: **Flux Kontext**.\ +As a reminder Flux Kontext is an image editor: give it an image and a prompt and it will make the change for you.\ +This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization. + +WanGP v7 comes with vanilla *Image2image* and *Vace FusioniX* models. However you can build your own finetune where you combine a text2video or Vace model with any combination of Loras. + +Also in the news: +- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected. +- *Film Grain* post processing to add a vintage look to your videos +- The *First Last Frame to Video* model should work much better now as I have recently discovered that its implementation was not complete +- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated. + + ### July 10 2025: WanGP v6.7, is NAG a game changer ?
you tell me Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index d8efcd9..14006b1 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -2,18 +2,15 @@ "model": { "name": "Flux Dev Kontext 12B", "architecture": "flux_dev_kontext", - "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that the output resolution is modified by Flux Kontext and may not be what you requested.", + "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image the output dimensions may not match the dimensions of the input image.", "URLs": [ - "c:/temp/kontext/flux1_kontext_dev_bf16.safetensors", - "c:/temp/kontext/flux1_kontext_dev_quanto_bf16_int8.safetensors" - ], - "URLs2": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" ] }, + "prompt": "add a hat", "resolution": "1280x720", - "video_length": "1" + "video_length": 1 } \ No newline at end of file diff --git a/defaults/t2i.json b/defaults/t2i.json new file mode 100644 index 0000000..f49f426 --- /dev/null +++ b/defaults/t2i.json @@ -0,0 +1,13 @@ +{ + "model": { + "name": "Wan2.1 text2image 14B", + "architecture": "t2v", + "description": "The original Wan Text 2 Video model configured to generate an image instead of a video.", + "image_outputs": true, + "URLs": "t2v" + }, + "video_length": 1, + "resolution": "1280x720" +} + + \ No newline at end of file diff --git a/defaults/vace_14B_fusionix.json b/defaults/vace_14B_fusionix.json index 99b07d1..44c048c 100644 --- a/defaults/vace_14B_fusionix.json +++ b/defaults/vace_14B_fusionix.json @@ -15,7 +15,7 @@ "seed": -1, "num_inference_steps": 10, "guidance_scale": 1, - "flow_shift": 5, + "flow_shift": 2, "embedded_guidance_scale": 6, "repeat_generation": 1, "multi_images_gen_type": 0, diff --git a/defaults/vace_14B_fusionix_t2i.json b/defaults/vace_14B_fusionix_t2i.json new file mode 100644 index 0000000..75fbf42 --- /dev/null +++ b/defaults/vace_14B_fusionix_t2i.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Vace FusioniX image2image 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "image_outputs": true, + "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", + "URLs": "t2v_fusionix" + }, + "resolution": "1280x720", + "guidance_scale": 1, + "num_inference_steps": 10, + "video_length": 1 +} \ No newline at end of file diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md index 28823f4..1c9ee6b 100644 --- a/docs/FINETUNES.md +++ b/docs/FINETUNES.md @@ -2,22 +2,30 @@ A Finetuned model is model that shares the same architecture of one 
specific model but has derived weights from this model. Some finetuned models have been created by combining multiple finetuned models. -As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP, however you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will do as usual all the work for you: autodownload the model and build the user interface. +As there are potentially an infinite number of finetunes, specific finetuned models are not known by default by WanGP. However you can create a finetuned model definition that will tell WanGP about the existence of this finetuned model and WanGP will as usual do all the work for you: autodownload the model and build the user interface. + +The WanGP finetune system can also be used to tweak default models: for instance you can add on top of an existing model some Loras that will always be applied transparently. Finetune models definitions are light json files that can be easily shared. You can find some of them on the WanGP *discord* server https://discord.gg/g7efUW9jGV +All the finetune definition files should be stored in the *finetunes/* subfolder. + Finetuned models have been tested so far with Wan2.1 text2video, Wan2.1 image2video, Hunyuan Video text2video. There isn't currently any support for LTX Video finetunes. -## Create a new Finetune Model Definition -All the finetune models definitions are json files stored in the **finetunes** sub folder. All the corresponding finetune model weights will be stored in the *ckpts* subfolder and will sit next to the base models. -WanGP comes with a few prebuilt finetune models that you can use as starting points and to get an idea of the structure of the definition file. + +## Create a new Finetune Model Definition +All the finetune model definitions are json files stored in the **finetunes/** sub folder. All the corresponding finetune model weights, once downloaded, will be stored in the *ckpts/* subfolder and will sit next to the base models. + +All the models used by WanGP are also described using the finetunes json format and can be found in the **defaults/** subfolder. Please don’t modify any file in the **defaults/** folder. + +However you can use these files as starting points for new definition files and to get an idea of the structure of a definition file. If you want to change how a base model is handled (title, default settings, path to model weights, …) you may override any property of the default definition file by creating a new file in the finetunes folder with the same name. Everything will happen as if the two definitions were merged property by property, with a higher priority given to the finetunes model definition. A definition is built from a *settings file* that can contain all the default parameters for a video generation. On top of this file a subtree named **model** contains all the information regarding the finetune (URLs to download the model, corresponding base model id, ...).
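To make the overall shape of a definition concrete before going through the individual properties, here is a minimal sketch of what a hypothetical *finetunes/t2v_my_finetune.json* could look like (the file name, Lora URL and description are placeholders; the exact list of supported properties is detailed in the sections below). It reuses the weights of the default *t2v* model by id and embeds an always-on accelerator Lora, two mechanisms described later in this guide:

```
{
  "model": {
    "name": "Wan2.1 text2video 14B My Finetune",
    "architecture": "t2v",
    "description": "Hypothetical example: the default t2v model with an accelerator Lora always applied.",
    "URLs": "t2v",
    "loras": [
      "https://example.com/path/to/my_accelerator_lora.safetensors"
    ],
    "loras_multipliers": [1.0]
  },
  "num_inference_steps": 10,
  "guidance_scale": 1,
  "resolution": "1280x720",
  "video_length": 81
}
```

Following the steps listed below, dropping such a file in the *finetunes/* subfolder and restarting WanGP should be enough for the new entry to appear in the model list.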
You can obtain a settings file in several ways: - In the subfolder **settings**, get the json file that corresponds to the base model of your finetune (see the next section for the list of ids of base models) -- From the user interface, go to the base model and click **export settings** +- From the user interface, select the base model for which you want to create a finetune and click **export settings** Here are steps: 1) Create a *settings file* @@ -26,45 +34,60 @@ Here are steps: 4) Restart WanGP ## Architecture Models Ids -A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities, here are Architecture Ids: -- *t2v*: Wan 2.1 Video text 2 -- *i2v*: Wan 2.1 Video image 2 480p -- *i2v_720p*: Wan 2.1 Video image 2 720p +A finetune is derived from a base model and will inherit all the user interface and corresponding model capabilities. Here are some Architecture Ids: +- *t2v*: Wan 2.1 Video text 2 video +- *i2v*: Wan 2.1 Video image 2 video 480p and 720p - *vace_14B*: Wan 2.1 Vace 14B - *hunyuan*: Hunyuan Video text 2 video - *hunyuan_i2v*: Hunyuan Video image 2 video +Any file name in the defaults subfolder (without the json extension) corresponds to an architecture id. + +Please note that the weights of some architectures correspond to a combination of the weights of one architecture completed by the weights of one or more modules. + +A module is a set of weights that is insufficient to be a model by itself but that can be added to an existing model to extend its capabilities. + +For instance if one adds the module *vace_14B* on top of a model with architecture *t2v* one gets a model with the *vace_14B* architecture. Here *vace_14B* stands for both an architecture name and a module name. The module system allows you to reuse shared weights between models. + + ## The Model Subtree - *name* : name of the finetune used to select - *architecture* : architecture Id of the base model of the finetune (see previous section) - *description*: description of the finetune that will appear at the top - *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8 bits quantized model that have been quantized using **quanto**. WanGP offers a command switch to build easily such a quantized model (see below). *URLs* can contain also paths to local file to allow testing. -- *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. So far the only module supported is Vace 14B (its id is *vace_14B*). For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module. +- *modules*: this is a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported modules so far are: *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module. - *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance) +-*loras* : URLs of Loras that will be applied before any other Lora specified by the user. These Loras will quite often be Loras accelerators.
For instance if you specify here the FusioniX Lora you will be able to reduce the number of generation steps. -*loras_multipliers* : a list of float numbers that defines the weight of each Lora mentioned above. - *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model +-*visible* : by default assumed to be true. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and hide it. +-*image_outputs* : turns any model that generates a video into a model that generates images. In fact it adapts the user interface for image generation and asks the model to generate a video with a single frame. + +In order to favor reusability, the properties *URLs*, *modules*, *loras* and *preload_URLs* can contain, instead of a list of URLs, a single text which corresponds to the id of a finetune or default model to reuse. + +For example let’s say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In *vace_fusionix.json* you can then write « URLs » : « t2v_fusionix » to automatically reuse the URLs already defined in the corresponding file. Example of a **model** subtree ``` - "model": - { - "name": "Wan text2video FusioniX 14B", - "architecture" : "t2v", - "description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.", - "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors" - ], + "model": + { + "name": "Wan text2video FusioniX 14B", + "architecture" : "t2v", + "description": "A powerful merged text-to-video model based on the original WAN 2.1 T2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. multiple open-source models and LoRAs to boost temporal quality, expressiveness, and motion realism.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors" + ], "preload_URLs": [ ], - "auto_quantize": true - }, + "auto_quantize": true + }, ``` ## Finetune Model Naming Convention If a model is not quantized, it is assumed to be mostly 16 bits (with maybe a few 32 bits weights), so *bf16* or *fp16* should appear somewhere in the name. If you need examples just look at the **ckpts** subfolder, the naming convention for the base models is the same. -If a model is quantized the term *quanto* should also be included since WanGP supports for the moment only *quanto* quantized model, most specically you should replace *fp16* by *quanto_fp16_int8* or *bf6* by *quanto_bf16_int8*.
+If a model is quantized the term *quanto* should also be included since WanGP supports for the moment only *quanto* quantized models, most specifically you should replace *fp16* by *quanto_fp16_int8* or *bf16* by *quanto_bf16_int8*. Please note it is important that *bf16*, *fp16* and *quanto* are all in lower case letters. @@ -82,4 +105,4 @@ If you launch the app with the *--save-quantized* switch, WanGP will create a qu You need to create a quantized model specifically for *bf16* or *fp16* as they can not be converted on the fly. However there is no need for a non quantized model as they can be converted on the fly while being loaded. -Wan models supports both *fp16* and *bf16* data types albeit *fp16* delivers in theory better quality. On the contrary Hunyuan and LTXV supports only *bf16*. \ No newline at end of file +Wan models support both *fp16* and *bf16* data types albeit *fp16* delivers in theory better quality. On the contrary Hunyuan and LTXV support only *bf16*. diff --git a/docs/LORAS.md b/docs/LORAS.md index be53458..0b2d034 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -6,18 +6,19 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a Loras are organized in different folders based on the model they're designed for: -### Text-to-Video Models +### Wan Text-to-Video Models - `loras/` - General t2v loras - `loras/1.3B/` - Loras specifically for 1.3B models - `loras/14B/` - Loras specifically for 14B models -### Image-to-Video Models +### Wan Image-to-Video Models - `loras_i2v/` - Image-to-video loras ### Other Models - `loras_hunyuan/` - Hunyuan Video t2v loras - `loras_hunyuan_i2v/` - Hunyuan Video i2v loras - `loras_ltxv/` - LTX Video loras +- `loras_flux/` - Flux loras ## Custom Lora Directory @@ -64,7 +65,7 @@ For dynamic effects over generation steps, use comma-separated values: ## Lora Presets -Presets are combinations of loras with predefined multipliers and prompts. +Lora Presets are combinations of loras with predefined multipliers and prompts. ### Creating Presets 1. Configure your loras and multipliers @@ -95,17 +96,37 @@ WanGP supports multiple lora formats: - **Replicate** format - **Standard PyTorch** (.pt, .pth) -## Safe-Forcing lightx2v Lora (Video Generation Accelerator) -Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models +## Loras Accelerators +Most Loras are used to apply a specific style or to alter the content of the output of the generated video. +However some Loras have been designed to transform a model into a distilled model which requires fewer steps to generate a video. + +You will find most *Loras Accelerators* here: +https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators ### Setup Instructions -1. Download the Lora: - ``` - https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors - ``` -2. Place in your `loras/` directory +1. Download the Lora +2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it is an i2v lora +## FusioniX (or FusionX) Lora +If you need just one Lora accelerator, use this one. It is a combination of multiple Loras accelerators (including Causvid below) and style loras. It will not only accelerate the video generation but it will also improve the quality.
There are two versions of this lora depending on whether you use it with a t2v or an i2v model. + +### Usage +1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) +2. Enable Advanced Mode +3. In Advanced Generation Tab: + - Set Guidance Scale = 1 + - Set Shift Scale = 2 +4. In Advanced Lora Tab: + - Select the FusioniX Lora + - Set multiplier to 1 +5. Set generation steps from 8-10 +6. Generate! + +## Safe-Forcing lightx2v Lora (Video Generation Accelerator) +The Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps; it also offers a 2x speed improvement since it doesn't require classifier free guidance. It works on both t2v and i2v models. +You will find it under the name *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors* + ### Usage 1. Select a Wan t2v or i2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) 2. Enable Advanced Mode @@ -118,17 +139,10 @@ Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distil 5. Set generation steps to 2-8 6. Generate! + ## CausVid Lora (Video Generation Accelerator) - CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement. -### Setup Instructions -1. Download the CausVid Lora: - ``` - https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors - ``` -2. Place in your `loras/` directory - ### Usage 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) 2. Enable Advanced Mode @@ -149,25 +163,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe *Note: Lower steps = lower quality (especially motion)* - ## AccVid Lora (Video Generation Accelerator) AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1). -### Setup Instructions -1. Download the AccVid Lora: - -- for t2v models: - ``` - https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors - ``` - -- for i2v models: - ``` - https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors - ``` - -2. Place in your `loras/` directory or `loras_i2v/` directory ### Usage 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model @@ -268,6 +267,7 @@ In the video, a man is presented. The man is in a city and looks at his watch. --lora-dir-hunyuan path # Path to Hunyuan t2v loras --lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras --lora-dir-ltxv path # Path to LTX Video loras +--lora-dir-flux path # Path to Flux loras --lora-preset preset # Load preset on startup --check-loras # Filter incompatible loras ``` \ No newline at end of file diff --git a/docs/MODELS.md b/docs/MODELS.md index c8187be..720cb73 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -2,6 +2,8 @@ WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations. +Most models can be combined with Loras Accelerators (check the Lora guide) to accelerate the generation of a video x2 or x3 with little quality loss + ## Wan 2.1 Text2Video Models Please note that that the term *Text2Video* refers to the underlying Wan architecture but as it has been greatly improved overtime many derived Text2Video models can now generate videos using images.
@@ -65,6 +67,12 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite ## Wan 2.1 Specialized Models +#### Multitalk +- **Type**: Multi Talking head animation +- **Input**: Voice track + image +- **Works on**: People +- **Use case**: Lip-sync and voice-driven animation for up to two people + #### FantasySpeaking - **Type**: Talking head animation - **Input**: Voice track + image @@ -82,7 +90,7 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite - **Requirements**: 81+ frame input videos, 15+ denoising steps - **Use case**: View same scene from different angles -#### Sky Reels v2 +#### Sky Reels v2 Diffusion - **Type**: Diffusion Forcing model - **Specialty**: "Infinite length" videos - **Features**: High quality continuous generation @@ -107,22 +115,6 @@ Please note that that the term *Text2Video* refers to the underlying Wan archite
-## Wan Special Loras -### Safe-Forcing lightx2v Lora -- **Type**: Distilled model (Lora implementation) -- **Speed**: 4-8 steps generation, 2x faster (no classifier free guidance) -- **Compatible**: Works with t2v and i2v Wan 14B models -- **Setup**: Requires Safe-Forcing lightx2v Lora (see [LORAS.md](LORAS.md)) - - -### Causvid Lora -- **Type**: Distilled model (Lora implementation) -- **Speed**: 4-12 steps generation, 2x faster (no classifier free guidance) -- **Compatible**: Works with Wan 14B models -- **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md)) - - -
## Hunyuan Video Models diff --git a/finetunes/put your finetunes here.txt b/finetunes/put your finetunes here.txt new file mode 100644 index 0000000..e69de29 diff --git a/flux/flux_main.py b/flux/flux_main.py index f4b1994..b782cc9 100644 --- a/flux/flux_main.py +++ b/flux/flux_main.py @@ -65,7 +65,7 @@ class model_factory: fit_into_canvas = None, callback = None, loras_slists = None, - frame_num = 1, + batch_size = 1, **bbargs ): @@ -89,7 +89,7 @@ class model_factory: img_cond=image_ref, target_width=width, target_height=height, - bs=frame_num, + bs=batch_size, seed=seed, device="cuda", ) diff --git a/postprocessing/film_grain.py b/postprocessing/film_grain.py new file mode 100644 index 0000000..a38b43a --- /dev/null +++ b/postprocessing/film_grain.py @@ -0,0 +1,21 @@ +# Thanks to https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/film_grain.py +import torch + +def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation: float = 0.5): + device = images.device + + images = images.permute(1, 2 ,3 ,0) + images.add_(1.).div_(2.) + grain = torch.randn_like(images, device=device) + grain[:, :, :, 0] *= 2 + grain[:, :, :, 2] *= 3 + grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat( + 1, 1, 1, 3 + ) * (1 - saturation) + + # Blend the grain with the image + noised_images = images + grain_intensity * grain + noised_images.clamp_(0, 1) + noised_images.sub_(.5).mul_(2.) + noised_images = noised_images.permute(3, 0, 1 ,2) + return noised_images diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index a6146d8..c8ea190 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -65,6 +65,7 @@ def get_frames_from_image(image_input, image_state): Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ + load_sam() user_name = time.time() frames = [image_input] * 2 # hardcode: mimic a video with 2 frames @@ -89,7 +90,7 @@ def get_frames_from_image(image_input, image_state): gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ - gr.update(visible=True), gr.update(visible=False), \ + gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=True) @@ -103,6 +104,8 @@ def get_frames_from_video(video_input, video_state): [[0:nearest_frame], [nearest_frame:], nearest_frame] """ + load_sam() + while model == None: time.sleep(1) @@ -273,6 +276,20 @@ def save_video(frames, output_path, fps): return output_path +def mask_to_xyxy_box(mask): + rows, cols = np.where(mask == 255) + xmin = min(cols) + xmax = max(cols) + 1 + ymin = min(rows) + ymax = max(rows) + 1 + xmin = max(xmin, 0) + ymin = max(ymin, 0) + xmax = min(xmax, mask.shape[1]) + ymax = min(ymax, mask.shape[0]) + box = [xmin, ymin, xmax, ymax] + box = [int(x) for x in box] + return box + # image matting def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter): matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) @@ -320,9 +337,17 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si foreground = output_frames foreground_output = Image.fromarray(foreground[-1]) - alpha_output = Image.fromarray(alpha[-1][:,:,0]) - - return foreground_output, gr.update(visible=True) + alpha_output = alpha[-1][:,:,0] + frame_temp = alpha_output.copy() + 
alpha_output[frame_temp > 127] = 0 + alpha_output[frame_temp <= 127] = 255 + bbox_info = mask_to_xyxy_box(alpha_output) + h = alpha_output.shape[0] + w = alpha_output.shape[1] + bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ] + bbox_info = ":".join(bbox_info) + alpha_output = Image.fromarray(alpha_output) + return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True) # video matting def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): @@ -469,6 +494,13 @@ def restart(): gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) +def load_sam(): + global model_loaded + global model + global matanyone_model + model.samcontroler.sam_controler.model.to(arg_device) + matanyone_model.to(arg_device) + def load_unload_models(selected): global model_loaded global model @@ -476,8 +508,7 @@ def load_unload_models(selected): if selected: # print("Matanyone Tab Selected") if model_loaded: - model.samcontroler.sam_controler.model.to(arg_device) - matanyone_model.to(arg_device) + load_sam() else: # args, defined in track_anything.py sam_checkpoint_url_dict = { @@ -522,12 +553,16 @@ def export_to_vace_video_input(foreground_video_output): def export_image(image_refs, image_output): gr.Info("Masked Image transferred to Current Video") - # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output if image_refs == None: image_refs =[] image_refs.append( image_output) return image_refs +def export_image_mask(image_input, image_mask): + gr.Info("Input Image & Mask transferred to Current Video") + return Image.fromarray(image_input), image_mask + + def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output): gr.Info("Original Video and Full Mask have been transferred") # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output @@ -543,7 +578,7 @@ def teleport_to_video_tab(tab_state): return gr.Tabs(selected="video_gen") -def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, vace_image_refs): +def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" @@ -677,7 +712,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button") with gr.Column(scale=2): alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video") - alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button") + export_image_mask_btn = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button") with gr.Row(): with gr.Row(visible= False): export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible= False) @@ -696,7 +731,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va ], outputs=[video_state, video_info, template_frame, image_selection_slider, end_selection_slider, 
track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame, - foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title] + foreground_video_output, alpha_video_output, foreground_output_button, export_image_mask_btn, mask_dropdown, step2_title] ) # second step: select images from slider @@ -755,7 +790,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va foreground_video_output, alpha_video_output, template_frame, image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click, - add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title + add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title ], queue=False, show_progress=False) @@ -770,7 +805,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va foreground_video_output, alpha_video_output, template_frame, image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type, clear_button_click, - add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title + add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, export_image_mask_btn, mask_dropdown, video_info, step2_title ], queue=False, show_progress=False) @@ -872,15 +907,19 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va # output image with gr.Row(equal_height=True): foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image") + alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image") + with gr.Row(equal_height=True): + bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False) with gr.Row(): - with gr.Row(): - export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button") - with gr.Column(scale=2, visible= False): - alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image") - alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button") + # with gr.Row(): + export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button") + # with gr.Column(scale=2, visible= True): + export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button") export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) + export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( 
#video_prompt_video_guide_trigger, + fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) # first step: get the image information extract_frames_button.click( @@ -890,9 +929,17 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va ], outputs=[image_state, image_info, template_frame, image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame, - foreground_image_output, alpha_image_output, export_image_btn, alpha_output_button, mask_dropdown, step2_title] + foreground_image_output, alpha_image_output, bbox_info, export_image_btn, export_image_mask_btn, mask_dropdown, step2_title] ) + # points clear + clear_button_click.click( + fn = clear_click, + inputs = [image_state, click_state,], + outputs = [template_frame,click_state], + ) + + # second step: select images from slider image_selection_slider.release(fn=select_image_template, inputs=[image_selection_slider, image_state, interactive_state], @@ -925,7 +972,7 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_video_mask, va matting_button.click( fn=image_matting, inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider], - outputs=[foreground_image_output, export_image_btn] + outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] ) diff --git a/wan/any2video.py b/wan/any2video.py index 91f7258..ba81250 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -61,6 +61,7 @@ class WanAny2V: checkpoint_dir, model_filename = None, model_type = None, + model_def = None, base_model_type = None, text_encoder_filename = None, quantizeTransformer = False, @@ -75,7 +76,8 @@ class WanAny2V: self.dtype = dtype self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype - + self.model_def = model_def + self.image_outputs = model_def.get("image_outputs", False) self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, @@ -106,7 +108,7 @@ class WanAny2V: # config = json.load(f) # from mmgp import safetensors2 # sd = safetensors2.torch_load_file(xmodel_filename) - # model_filename = "c:/temp/fflf/diffusion_pytorch_model-00001-of-00007.safetensors" + # model_filename = "c:/temp/flf/diffusion_pytorch_model-00001-of-00007.safetensors" base_config_file = f"configs/{base_model_type}.json" forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" @@ -208,7 +210,7 @@ class WanAny2V: if refs is not None: length = len(refs) - mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask_pad = torch.zeros(mask.shape[0], length, *mask.shape[-2:], dtype=mask.dtype, device=mask.device) mask = torch.cat((mask_pad, mask), dim=1) result_masks.append(mask) return result_masks @@ -327,20 +329,6 @@ class WanAny2V: self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref return src_video, src_mask, src_ref_images - def decode_latent(self, zs, ref_images=None, tile_size= 0 ): - if ref_images is None: - ref_images = [None] * len(zs) - # else: - # assert len(zs) == len(ref_images) - - trimed_zs = [] - for z, refs in zip(zs, ref_images): - if refs is not None: - z = z[:, len(refs):, :, :] - trimed_zs.append(z) - - return self.vae.decode(trimed_zs, tile_size= 
tile_size) - def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] for ref_image in ref_images: @@ -366,6 +354,7 @@ class WanAny2V: height = 720, fit_into_canvas = True, frame_num=81, + batch_size = 1, shift=5.0, sample_solver='unipc', sampling_steps=50, @@ -397,6 +386,7 @@ class WanAny2V: NAG_alpha = 0.5, offloadobj = None, apg_switch = False, + speakers_bboxes = None, **bbargs ): @@ -554,8 +544,8 @@ class WanAny2V: overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4) if overlapped_latents != None: # disabled because looks worse - if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:] - extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone() + if False and overlapped_latents_frames_num > 1: lat_y[:, :, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:] + extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) y = torch.concat([msk, lat_y]) lat_y = None kwargs.update({'clip_fea': clip_context, 'y': y}) @@ -586,7 +576,7 @@ class WanAny2V: overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: overlapped_latents_frames_num = overlapped_frames_num = 0 - if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] + if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) latent_keep_frames = [] if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0: @@ -609,6 +599,7 @@ class WanAny2V: input_ref_images = self.get_vae_latents(input_ref_images, self.device) input_ref_images_neg = torch.zeros_like(input_ref_images) ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 + trim_frames = input_ref_images.shape[1] # Vace if vace : @@ -633,8 +624,8 @@ class WanAny2V: context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : - overlapped_latents_size = overlapped_latents.shape[1] - extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone() + overlapped_latents_size = overlapped_latents.shape[2] + extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) target_shape = list(z0[0].shape) target_shape[0] = int(target_shape[0] / 2) @@ -649,7 +640,7 @@ class WanAny2V: from wan.multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) - token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None + token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = speakers_bboxes).to(self.dtype) if human_no > 1 else None if fantasy and audio_proj != None: kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, }) @@ -658,8 +649,8 @@ class WanAny2V: if self._interrupt: return None + expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - batch_size = 1 if target_camera != None: shape 
= list(target_shape[1:]) shape[0] *= 2 @@ -698,14 +689,14 @@ class WanAny2V: if sample_scheduler != None: scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} - - latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + # b, c, lat_f, lat_h, lat_w + latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if apg_switch != 0: apg_momentum = -0.75 apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + # self.image_outputs = False # denoising for i, t in enumerate(tqdm(timesteps)): offload.set_step_no_for_lora(self.model, i) @@ -715,36 +706,36 @@ class WanAny2V: if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: sigma = t / 1000 - noise = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: new_latents = latents.clone() - new_latents[:, :source_latents.shape[1] ] = noise[:, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents + new_latents[:,:, :source_latents.shape[1] ] = noise[:, :, :source_latents.shape[1] ] * sigma + (1 - sigma) * source_latents.unsqueeze(0) for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: - new_latents[:, latent_no:latent_no+1 ] = latents[:, latent_no:latent_no+1] + new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] latents = new_latents new_latents = None else: - latents = noise * sigma + (1 - sigma) * source_latents + latents = noise * sigma + (1 - sigma) * source_latents.unsqueeze(0) noise = None if extended_overlapped_latents != None: latent_noise_factor = t / 1000 - latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor + latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor if vace: overlap_noise_factor = overlap_noise / 1000 for zz in z: - zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor + zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor if target_camera != None: - latent_model_input = torch.cat([latents, source_latents], dim=1) + latent_model_input = torch.cat([latents, source_latents.unsqueeze(0).expand(*expand_shape)], dim=2) # !!!! 
else: latent_model_input = latents if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 + - [ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: @@ -832,38 +823,41 @@ class WanAny2V: if sample_solver == "euler": dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) dt = dt / self.num_timesteps - latents = latents - noise_pred * dt[:, None, None, None] + latents = latents - noise_pred * dt[:, None, None, None, None] else: - temp_x0 = sample_scheduler.step( - noise_pred[:, :target_shape[1]].unsqueeze(0), + latents = sample_scheduler.step( + noise_pred[:, :, :target_shape[1]], t, - latents.unsqueeze(0), + latents, **scheduler_kwargs)[0] - latents = temp_x0.squeeze(0) - del temp_x0 if callback is not None: - callback(i, latents, False) + latents_preview = latents + if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] + if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) + callback(i, latents_preview[0], False) + latents_preview = None - x0 = [latents] + if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if trim_frames > 0: latents= latents[:, :,:-trim_frames] + if return_latent_slice != None: + latent_slice = latents[:, :, return_latent_slice].clone() + + x0 =latents.unbind(dim=0) if chipmunk: self.model.release_chipmunk() # need to add it at every exit when in prod - if return_latent_slice != None: - latent_slice = latents[:, return_latent_slice].clone() - if vace: - # vace post processing - videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) - else: - if phantom and input_ref_images != None: - trim_frames = input_ref_images.shape[1] - if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0] - videos = self.vae.decode(x0, VAE_tile_size) + videos = self.vae.decode(x0, VAE_tile_size) + if self.image_outputs: + videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0] + else: + videos = videos[0] # return only first video if return_latent_slice != None: - return { "x" : videos[0], "latent_slice" : latent_slice } - return videos[0] + return { "x" : videos, "latent_slice" : latent_slice } + return videos def adapt_vace_model(self): model = self.model diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index ee168ec..9b5918a 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -31,6 +31,7 @@ class DTT2V: rank=0, model_filename = None, model_type = None, + model_def = None, base_model_type = None, save_quantized = False, text_encoder_filename = None, @@ -53,6 +54,8 @@ class DTT2V: checkpoint_path=text_encoder_filename, tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), shard_fn= None) + self.model_def = model_def + self.image_outputs = model_def.get("image_outputs", False) self.vae_stride = config.vae_stride self.patch_size = config.patch_size @@ -202,6 +205,7 @@ class DTT2V: width: int = 832, fit_into_canvas = True, frame_num: int = 97, + batch_size = 1, sampling_steps: int = 50, shift: float = 1.0, 
guide_scale: float = 5.0, @@ -224,8 +228,9 @@ class DTT2V: generator = torch.Generator(device=self.device) generator.manual_seed(seed) self._guidance_scale = guide_scale - frame_num = max(17, frame_num) # must match causal_block_size for value of 5 - frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) + if frame_num > 1: + frame_num = max(17, frame_num) # must match causal_block_size for value of 5 + frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) if ar_step == 0: causal_block_size = 1 @@ -244,7 +249,7 @@ class DTT2V: image_start = np.array(image_start.resize((width, height))).transpose(2, 0, 1) - latent_length = (frame_num - 1) // 4 + 1 + latent_length = (frame_num - 1) // 4 + 1 latent_height = height // 8 latent_width = width // 8 @@ -297,12 +302,12 @@ class DTT2V: prefix_video = prefix_video[:, : predix_video_latent_length] base_num_frames_iter = latent_length - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + latent_shape = [batch_size, 16, base_num_frames_iter, latent_height, latent_width] latents = self.prepare_latents( latent_shape, dtype=torch.float32, device=self.device, generator=generator ) if prefix_video is not None: - latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32) + latents[:, :, :predix_video_latent_length] = prefix_video.to(torch.float32) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, init_timesteps, @@ -340,7 +345,7 @@ class DTT2V: else: self.model.enable_cache = None from mmgp import offload - freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False) + freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False) kwrags = { "freqs" :freqs, "fps" : fps_embeds, @@ -358,15 +363,15 @@ class DTT2V: update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone() + latent_model_input = latents[:, :, valid_interval_start:valid_interval_end, :, :].clone() if overlap_noise > 0 and valid_interval_start < predix_video_latent_length: noise_factor = 0.001 * overlap_noise timestep_for_noised_condition = overlap_noise - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, :, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:predix_video_latent_length] ) * noise_factor ) @@ -417,18 +422,27 @@ class DTT2V: del noise_pred_cond, noise_pred_uncond for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], + latents[:, :, idx] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start], timestep_i[idx], - latents[:, idx], + latents[:, :, idx], return_dict=False, generator=generator, )[0] sample_schedulers_counter[idx] += 1 if callback is not None: - callback(i, latents.squeeze(0), False) + latents_preview = latents + if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) + callback(i, latents_preview[0], False) + latents_preview = None - 
x0 = latents.unsqueeze(0) - videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - return output_video + x0 =latents.unbind(dim=0) + + videos = self.vae.decode(x0, VAE_tile_size) + + if self.image_outputs: + videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0] + else: + videos = videos[0] # return only first video + + return videos diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 41a934b..a95332d 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -185,7 +185,7 @@ def pay_attention( q,k,v = qkv_list qkv_list.clear() out_dtype = q.dtype - if q.dtype == torch.bfloat16 and not bfloat16_supported: + if q.dtype == torch.bfloat16 and not bfloat16_supported: q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) @@ -194,7 +194,9 @@ def pay_attention( q = q.to(v.dtype) k = k.to(v.dtype) - + batch = len(q) + if len(k) != batch: k = k.expand(batch, -1, -1, -1) + if len(v) != batch: v = v.expand(batch, -1, -1, -1) if attn == "chipmunk": from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG diff --git a/wan/modules/model.py b/wan/modules/model.py index b7a2670..7d6357d 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -33,9 +33,10 @@ def sinusoidal_embedding_1d(dim, position): def reshape_latent(latent, latent_frames): - if latent_frames == latent.shape[0]: - return latent - return latent.reshape(latent_frames, -1, latent.shape[-1] ) + return latent.reshape(latent.shape[0], latent_frames, -1, latent.shape[-1] ) + +def restore_latent_shape(latent): + return latent.reshape(latent.shape[0], -1, latent.shape[-1] ) def identify_k( b: float, d: int, N: int): @@ -493,7 +494,7 @@ class WanAttentionBlock(nn.Module): x_mod = reshape_latent(x_mod , latent_frames) x_mod *= 1 + e[1] x_mod += e[0] - x_mod = reshape_latent(x_mod , 1) + x_mod = restore_latent_shape(x_mod) if cam_emb != None: cam_emb = self.cam_encoder(cam_emb) cam_emb = cam_emb.repeat(1, 2, 1) @@ -510,7 +511,7 @@ class WanAttentionBlock(nn.Module): x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames) x.addcmul_(y, e[2]) - x, y = reshape_latent(x , 1), reshape_latent(y , 1) + x, y = restore_latent_shape(x), restore_latent_shape(y) del y y = self.norm3(x) y = y.to(attention_dtype) @@ -542,7 +543,7 @@ class WanAttentionBlock(nn.Module): y = reshape_latent(y , latent_frames) y *= 1 + e[4] y += e[3] - y = reshape_latent(y , 1) + y = restore_latent_shape(y) y = y.to(attention_dtype) ffn = self.ffn[0] @@ -562,7 +563,7 @@ class WanAttentionBlock(nn.Module): y = y.to(dtype) x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames) x.addcmul_(y, e[5]) - x, y = reshape_latent(x , 1), reshape_latent(y , 1) + x, y = restore_latent_shape(x), restore_latent_shape(y) if hints_processed is not None: for hint, scale in zip(hints_processed, context_scale): @@ -669,6 +670,8 @@ class VaceWanAttentionBlock(WanAttentionBlock): hints[0] = None if self.block_id == 0: c = self.before_proj(c) + bz = x.shape[0] + if bz > c.shape[0]: c = c.repeat(bz, 1, 1 ) c += x c = super().forward(c, **kwargs) c_skip = self.after_proj(c) @@ -707,7 +710,7 @@ class Head(nn.Module): x = reshape_latent(x , latent_frames) x *= (1 + e[1]) x += e[0] - x = reshape_latent(x , 1) + x = restore_latent_shape(x) x= x.to(self.head.weight.dtype) x = self.head(x) return x @@ -1162,11 +1165,15 @@ class WanModel(ModelMixin, ConfigMixin): x_list[i] = 
x_list[0].clone() last_x_idx = i else: - # image source + # image source + bz = len(x) if y is not None: - x = torch.cat([x, y], dim=0) + y = y.unsqueeze(0) + if bz > 1: y = y.expand(bz, -1, -1, -1, -1) + x = torch.cat([x, y], dim=1) # embeddings - x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) + # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) + x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] if chipmunk: x = x.unsqueeze(-1) @@ -1204,7 +1211,7 @@ class WanModel(ModelMixin, ConfigMixin): ) # b, dim e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) - if self.inject_sample_info: + if self.inject_sample_info and fps!=None: fps = torch.tensor(fps, dtype=torch.long, device=device) fps_emb = self.fps_embedding(fps).to(e.dtype) @@ -1402,7 +1409,7 @@ class WanModel(ModelMixin, ConfigMixin): x_list[i] = self.unpatchify(x, grid_sizes) del x - return [x[0].float() for x in x_list] + return [x.float() for x in x_list] def unpatchify(self, x, grid_sizes): r""" @@ -1427,7 +1434,10 @@ class WanModel(ModelMixin, ConfigMixin): u = torch.einsum('fhwpqrc->cfphqwr', u) u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) out.append(u) - return out + if len(x) == 1: + return out[0].unsqueeze(0) + else: + return torch.stack(out, 0) def init_weights(self): r""" diff --git a/wan/multitalk/attention.py b/wan/multitalk/attention.py index 44c1ca0..12fb317 100644 --- a/wan/multitalk/attention.py +++ b/wan/multitalk/attention.py @@ -333,7 +333,7 @@ class SingleStreamMutiAttention(SingleStreamAttention): human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1])) human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1])) - back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device) + back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype, device=human1.device) max_indices = x_ref_attn_map.argmax(dim=0) normalized_map = torch.stack([human1, human2, back], dim=1) normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N @@ -351,7 +351,7 @@ class SingleStreamMutiAttention(SingleStreamAttention): if self.qk_norm: encoder_k = self.add_k_norm(encoder_k) - per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device) + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2 per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 encoder_pos = torch.concat([per_frame]*N_t, dim=0) diff --git a/wan/multitalk/multitalk.py b/wan/multitalk/multitalk.py index 3945682..038efdf 100644 --- a/wan/multitalk/multitalk.py +++ b/wan/multitalk/multitalk.py @@ -272,6 +272,34 @@ def timestep_transform( new_t = new_t * num_timesteps return new_t +def parse_speakers_locations(speakers_locations): + bbox = {} + if speakers_locations is None or len(speakers_locations) == 0: + return None, "" + speakers = speakers_locations.split(" ") + if len(speakers) !=2: + error= "Two speakers locations should be defined" + return "", error + + for i, speaker in enumerate(speakers): + location = speaker.strip().split(":") + if len(location) not in (2,4): + error = f"Invalid Speaker Location '{location}'. 
A Speaker Location should be defined in the format Left:Right or using a BBox Left:Top:Right:Bottom" + return "", error + try: + good = False + location_float = [ float(val) for val in location] + good = all( 0 <= val <= 100 for val in location_float) + except: + pass + if not good: + error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100." + return "", error + if len(location_float) == 2: + location_float = [location_float[0], 0, location_float[1], 100] + bbox[f"human{i}"] = location_float + return bbox, "" + # construct human mask def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05, bbox = None): @@ -286,7 +314,9 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05 assert len(bbox) == HUMAN_NUMBER, f"The number of target bbox should be the same with cond_audio" background_mask = torch.zeros([src_h, src_w]) for _, person_bbox in bbox.items(): - x_min, y_min, x_max, y_max = person_bbox + y_min, x_min, y_max, x_max = person_bbox + x_min, y_min, x_max, y_max = max(x_min,5), max(y_min, 5), min(x_max,95), min(y_max,95) + x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100) human_mask = torch.zeros([src_h, src_w]) human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1 background_mask += human_mask @@ -306,7 +336,7 @@ def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale = 0.05 human_masks = [human_mask1, human_mask2] background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1)) human_masks.append(background_mask) - + # toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy()) ref_target_masks = torch.stack(human_masks, dim=0) #.to(self.device) # resize and centercrop for ref_target_masks # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w)) diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py index 4054361..4722dae 100644 --- a/wan/multitalk/multitalk_utils.py +++ b/wan/multitalk/multitalk_utils.py @@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli _, seq_lens, heads, _ = visual_q.shape class_num, _ = ref_target_masks.shape - x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype) + x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) split_chunk = heads // split_num diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 6ffcc29..fe0b6bf 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -5,7 +5,8 @@ import os import os.path as osp import torchvision.transforms.functional as TF import torch.nn.functional as F - +import cv2 +import tempfile import imageio import torch import decord @@ -101,6 +102,29 @@ def get_video_frame(file_name, frame_no): img = Image.fromarray(frame.numpy().astype(np.uint8)) return img +def convert_image_to_video(image): + if image is None: + return None + + # Convert PIL/numpy image to OpenCV format if needed + if isinstance(image, np.ndarray): + # Gradio images are typically RGB, OpenCV expects BGR + img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + else: + # Handle PIL Image + img_array = np.array(image) + img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) + + height, width = img_bgr.shape[:2] + + # Create temporary video file (auto-cleaned by Gradio) + with tempfile.NamedTemporaryFile(suffix='.mp4',
delete=False) as temp_video: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height)) + out.write(img_bgr) + out.release() + return temp_video.name + def resize_lanczos(img, h, w): img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) img = img.resize((w,h), resample=Image.Resampling.LANCZOS) diff --git a/wgp.py b/wgp.py index c9fc10a..8e5335b 100644 --- a/wgp.py +++ b/wgp.py @@ -16,7 +16,7 @@ import json import wan from wan.utils import notification_sound from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS -from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date +from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files from wan.modules.attention import get_attention_modes, get_supported_attention_modes @@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.1" -WanGP_version = "6.7" +WanGP_version = "7.0" settings_version = 2.22 max_source_video_frames = 1000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -171,6 +171,8 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Warning("Internal state error: Could not retrieve inputs for the model.") queue = gen.get("queue", []) return get_queue_table(queue) + model_def = get_model_def(model_type) + image_outputs = model_def.get("image_outputs", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -182,7 +184,7 @@ def process_prompt_and_add_tasks(state, model_choice): if frames_count > max_source_video_frames: gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. 
Output Video will be truncated") # return - for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask"]: + for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask", "image_mask"]: inputs[k] = None inputs.update(edit_overrides) del gen["edit_video_source"], gen["edit_overrides"] @@ -193,6 +195,13 @@ def process_prompt_and_add_tasks(state, model_choice): if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"] temporal_upsampling = inputs.get("temporal_upsampling","") if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] + if image_outputs and len(temporal_upsampling) > 0: + gr.Info("Temporal Upsampling can not be used with an Image") + return + film_grain_intensity = inputs.get("film_grain_intensity",0) + film_grain_saturation = inputs.get("film_grain_saturation",0.5) + # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] + if film_grain_intensity >0: prompt += ["Film Grain"] MMAudio_setting = inputs.get("MMAudio_setting",0) seed = inputs.get("seed",None) repeat_generation= inputs.get("repeat_generation",1) @@ -201,7 +210,7 @@ def process_prompt_and_add_tasks(state, model_choice): return if MMAudio_setting !=0: prompt += ["MMAudio"] if len(prompt) == 0: - gr.Info("You must choose at leat one Post Processing Method") + gr.Info("You must choose at least one Post Processing Method") return inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) @@ -247,7 +256,10 @@ def process_prompt_and_add_tasks(state, model_choice): audio_guide = inputs["audio_guide"] audio_guide2 = inputs["audio_guide2"] video_guide = inputs["video_guide"] + image_guide = inputs["image_guide"] video_mask = inputs["video_mask"] + image_mask = inputs["image_mask"] + speakers_locations = inputs["speakers_locations"] video_source = inputs["video_source"] frames_positions = inputs["frames_positions"] keep_frames_video_guide= inputs["keep_frames_video_guide"] @@ -269,6 +281,13 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Mag Cache maximum number of steps is 50") return + if "B" in audio_prompt_type or "X" in audio_prompt_type: + from wan.multitalk.multitalk import parse_speakers_locations + speakers_bboxes, error = parse_speakers_locations(speakers_locations) + if len(error) > 0: + gr.Info(error) + return + if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: @@ -314,12 +333,16 @@ def process_prompt_and_add_tasks(state, model_choice): audio_guide2 = None if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type): - if not "I" in video_prompt_type: - gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame that contains the two people one on each side ") + if not "I" in video_prompt_type and not not "V" in video_prompt_type: + gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side") - if "R" in audio_prompt_type and len(filter_letters(image_prompt_type, "VLG")) > 0 : - gr.Info("Remuxing is not yet supported if there is a video source") - audio_prompt_type= replace("R" ,"") + if 
len(filter_letters(image_prompt_type, "VL")) > 0 : + if "R" in audio_prompt_type: + gr.Info("Remuxing is not yet supported if there is a video source") + audio_prompt_type= audio_prompt_type.replace("R" ,"") + if "A" in audio_prompt_type: + gr.Info("Creating an Audio track is not yet supported if there is a video source") + return if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: if image_refs == None : @@ -342,17 +365,26 @@ def process_prompt_and_add_tasks(state, model_choice): image_refs = None if "V" in video_prompt_type: - if video_guide == None: - gr.Info("You must provide a Control Video") + if video_guide is None and image_guide is None: + if image_outputs: + gr.Info("You must provide a Control Image") + else: + gr.Info("You must provide a Control Video") return if "A" in video_prompt_type and not "U" in video_prompt_type: - if video_mask == None: - gr.Info("You must provide a Video Mask") + if video_mask is None and image_mask is None: + if image_outputs: + gr.Info("You must provide an Image Mask") + else: + gr.Info("You must provide a Video Mask") return else: video_mask = None + image_mask = None - if not "G" in video_prompt_type: + if "G" in video_prompt_type: + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ") + else: denoising_strength = 1.0 _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) @@ -361,7 +393,9 @@ def process_prompt_and_add_tasks(state, model_choice): return else: video_guide = None + image_guide = None video_mask = None + image_mask = None keep_frames_video_guide = "" denoising_strength = 1.0 @@ -416,10 +450,6 @@ def process_prompt_and_add_tasks(state, model_choice): if "hunyuan_custom_custom_edit" in model_filename: - if video_guide == None: - gr.Info("You must provide a Control Video") - return - if len(keep_frames_video_guide) > 0: gr.Info("Filtering Frames with this model is not supported") return @@ -440,7 +470,9 @@ def process_prompt_and_add_tasks(state, model_choice): "audio_guide": audio_guide, "audio_guide2": audio_guide2, "video_guide": video_guide, + "image_guide": image_guide, "video_mask": video_mask, + "image_mask": image_mask, "video_source": video_source, "frames_positions": frames_positions, "keep_frames_video_source": keep_frames_video_source, @@ -517,15 +549,15 @@ def process_prompt_and_add_tasks(state, model_choice): return update_queue_data(queue) def get_preview_images(inputs): - inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "video_mask", "image_refs" ] - labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Video Mask", "Image Reference"] + inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] + labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] start_image_data = None start_image_labels = [] end_image_data = None end_image_labels = [] for label, name in zip(labels,inputs_to_query): image= inputs.get(name, None) - if image != None: + if image is not None: image= [image] if not isinstance(image, list) else image.copy() if start_image_data == None: start_image_data = image @@ -645,7 +677,7 @@ def save_queue_action(state): params_copy = task.get('params', {}).copy() task_id_s = task.get('id', f"task_{task_index}") - image_keys = ["image_start", "image_end", 
"image_refs"] + image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] for key in image_keys: @@ -821,7 +853,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): max_id_in_file = max(max_id_in_file, task_id_loaded) params['state'] = state - image_keys = ["image_start", "image_end", "image_refs"] + image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] loaded_pil_images = {} @@ -1041,7 +1073,7 @@ def autosave_queue(): params_copy = task.get('params', {}).copy() task_id_s = task.get('id', f"task_{task_index}") - image_keys = ["image_start", "image_end", "image_refs"] + image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] for key in image_keys: @@ -1929,110 +1961,112 @@ def get_default_settings(model_type): i2v = test_class_i2v(model_type) defaults_filename = get_settings_file_name(model_type) if not Path(defaults_filename).is_file(): + ui_defaults = { + "prompt": get_default_prompt(i2v), + "resolution": "1280x720" if "720" in model_type else "832x480", + "video_length": 81, + "num_inference_steps": 30, + "seed": -1, + "repeat_generation": 1, + "multi_images_gen_type": 0, + "guidance_scale": 5.0, + "embedded_guidance_scale" : 6.0, + "flow_shift": 7.0 if not "720" in model_type and i2v else 5.0, + "negative_prompt": "", + "activated_loras": [], + "loras_multipliers": "", + "skip_steps_multiplier": 1.5, + "skip_steps_start_step_perc": 20, + "RIFLEx_setting": 0, + "slg_switch": 0, + "slg_layers": [9], + "slg_start_perc": 10, + "slg_end_perc": 90 + } + if model_type in ["fantasy"]: + ui_defaults["audio_guidance_scale"] = 5.0 + elif model_type in ["multitalk"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "audio_guidance_scale": 4, + "sliding_window_discard_last_frames" : 4, + "sample_solver" : "euler", + "adaptive_switch" : 1, + }) + + elif model_type in ["hunyuan","hunyuan_i2v"]: + ui_defaults.update({ + "guidance_scale": 7.0, + }) + + elif model_type in ["flux_dev_kontext"]: + ui_defaults.update({ + "video_prompt_type": "I", + }) + elif model_type in ["sky_df_1.3B", "sky_df_14B"]: + ui_defaults.update({ + "guidance_scale": 6.0, + "flow_shift": 8, + "sliding_window_discard_last_frames" : 0, + "resolution": "1280x720" if "720" in model_type else "960x544", + "sliding_window_size" : 121 if "720" in model_type else 97, + "RIFLEx_setting": 2, + "guidance_scale": 6, + "flow_shift": 8, + }) + + + elif model_type in ["phantom_1.3B", "phantom_14B"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "remove_background_images_ref": 0, + "video_prompt_type": "I", + # "resolution": "1280x720" + }) + + elif model_type in ["hunyuan_custom"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "resolution": "1280x720", + "video_prompt_type": "I", + }) + elif model_type in ["hunyuan_custom_audio"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "I", + }) + elif model_type in ["hunyuan_custom_edit"]: + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 13, + "video_prompt_type": "MVAI", + "sliding_window_size": 129, + }) + elif model_type in ["hunyuan_avatar"]: + ui_defaults.update({ + "guidance_scale": 7.5, + 
"flow_shift": 5, + "skip_steps_start_step_perc": 25, + "video_length": 129, + "video_prompt_type": "I", + }) + elif model_type in ["vace_14B", "vace_multitalk_14B"]: + ui_defaults.update({ + "sliding_window_discard_last_frames": 0, + }) + + model_def = get_model_def(model_type) if model_def != None: - ui_defaults = model_def["settings"] - if len(ui_defaults.get("prompt","")) == 0: - ui_defaults["prompt"]= get_default_prompt(i2v) - else: - ui_defaults = { - "prompt": get_default_prompt(i2v), - "resolution": "1280x720" if "720" in model_type else "832x480", - "video_length": 81, - "num_inference_steps": 30, - "seed": -1, - "repeat_generation": 1, - "multi_images_gen_type": 0, - "guidance_scale": 5.0, - "embedded_guidance_scale" : 6.0, - "flow_shift": 7.0 if not "720" in model_type and i2v else 5.0, - "negative_prompt": "", - "activated_loras": [], - "loras_multipliers": "", - "skip_steps_multiplier": 1.5, - "skip_steps_start_step_perc": 20, - "RIFLEx_setting": 0, - "slg_switch": 0, - "slg_layers": [9], - "slg_start_perc": 10, - "slg_end_perc": 90 - } - if model_type in ["fantasy"]: - ui_defaults["audio_guidance_scale"] = 5.0 - elif model_type in ["multitalk"]: - ui_defaults.update({ - "guidance_scale": 5.0, - "flow_shift": 7, # 11 for 720p - "audio_guidance_scale": 4, - "sliding_window_discard_last_frames" : 4, - "sample_solver" : "euler", - "adaptive_switch" : 1, - }) + ui_defaults_update = model_def["settings"] + ui_defaults.update(ui_defaults_update) - elif model_type in ["hunyuan","hunyuan_i2v"]: - ui_defaults.update({ - "guidance_scale": 7.0, - }) - - elif model_type in ["flux_dev_kontext"]: - ui_defaults.update({ - "video_prompt_type": "I", - }) - elif model_type in ["sky_df_1.3B", "sky_df_14B"]: - ui_defaults.update({ - "guidance_scale": 6.0, - "flow_shift": 8, - "sliding_window_discard_last_frames" : 0, - "resolution": "1280x720" if "720" in model_type else "960x544", - "sliding_window_size" : 121 if "720" in model_type else 97, - "RIFLEx_setting": 2, - "guidance_scale": 6, - "flow_shift": 8, - }) - - - elif model_type in ["phantom_1.3B", "phantom_14B"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 5, - "remove_background_images_ref": 0, - "video_prompt_type": "I", - # "resolution": "1280x720" - }) - - elif model_type in ["hunyuan_custom"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "resolution": "1280x720", - "video_prompt_type": "I", - }) - elif model_type in ["hunyuan_custom_audio"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "I", - }) - elif model_type in ["hunyuan_custom_edit"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 13, - "video_prompt_type": "MVAI", - "sliding_window_size": 129, - }) - elif model_type in ["hunyuan_avatar"]: - ui_defaults.update({ - "guidance_scale": 7.5, - "flow_shift": 5, - "skip_steps_start_step_perc": 25, - "video_length": 129, - "video_prompt_type": "I", - }) - elif model_type in ["vace_14B", "vace_multitalk_14B"]: - ui_defaults.update({ - "sliding_window_discard_last_frames": 0, - }) - + if len(ui_defaults.get("prompt","")) == 0: + ui_defaults["prompt"]= get_default_prompt(i2v) with open(defaults_filename, "w", encoding="utf-8") as f: json.dump(ui_defaults, f, indent=4) @@ -2489,7 +2523,8 @@ def load_wan_model(model_filename, model_type, base_model_type, model_def, quant config=cfg, checkpoint_dir="ckpts", model_filename=model_filename, - model_type = model_type, + model_type = model_type, + model_def = model_def, 
base_model_type=base_model_type, text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), quantizeTransformer = quantizeTransformer, @@ -2598,7 +2633,7 @@ def load_models(model_type): save_quantized = False print("Need to provide a non quantized model to create a quantized model to be saved") if save_quantized and len(modules) > 0: - print(f"Unable to create a finetune quantized model as some modules are declared in the finetune definition. If your finetune includes already the module weights you can remove the 'modules' entry and try again. If not you will need also to change temporarly the model 'architecture' to an architecture that wont require the modules part ('{model_types_no_module[0] if len(model_types_no_module)>0 else ''}' ?) to quantize and then add back the original 'modules' and 'architecture' entries.") + print(f"Unable to create a finetune quantized model as some modules are declared in the finetune definition. If your finetune includes already the module weights you can remove the 'modules' entry and try again. If not you will need also to change temporarly the model 'architecture' to an architecture that wont require the modules part ({modules}) to quantize and then add back the original 'modules' and 'architecture' entries.") save_quantized = False quantizeTransformer = not save_quantized and model_def !=None and transformer_quantization in ("int8", "fp8") and model_def.get("auto_quantize", False) and not "quanto" in model_filename if quantizeTransformer and len(modules) > 0: @@ -2931,8 +2966,10 @@ def refresh_gallery(state): #, msg prompt = task["prompt"] params = task["params"] model_type = params["model_type"] - model_type = get_base_model_type(model_type) - onemorewindow_visible = test_any_sliding_window(model_type) + base_model_type = get_base_model_type(model_type) + model_def = get_model_def(model_type) + is_image = model_def.get("image_outputs", False) + onemorewindow_visible = test_any_sliding_window(base_model_type) and not is_image enhanced = False if prompt.startswith("!enhanced!\n"): enhanced = True @@ -3047,7 +3084,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): pp_values= [] pp_labels = [] extension = os.path.splitext(file_name)[-1] - if not extension in [".mp4"]: + if not has_video_file_extension(file_name): img = Image.open(file_name) width, height = img.size configs = None @@ -3064,6 +3101,8 @@ def select_video(state, input_file_list, event_data: gr.EventData): misc_labels += ["Model"] video_temporal_upsampling = configs.get("temporal_upsampling", "") video_spatial_upsampling = configs.get("spatial_upsampling", "") + video_film_grain_intensity = configs.get("film_grain_intensity", 0) + video_film_grain_saturation = configs.get("film_grain_saturation", 0.5) video_MMAudio_setting = configs.get("MMAudio_setting", 0) video_MMAudio_prompt = configs.get("MMAudio_prompt", "") video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "") @@ -3074,6 +3113,9 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(video_temporal_upsampling) > 0: pp_values += [ video_temporal_upsampling ] pp_labels += [ "Upsampling" ] + if video_film_grain_intensity > 0: + pp_values += [ f"Intensity={video_film_grain_intensity}, Saturation={video_film_grain_saturation}" ] + pp_labels += [ "Film Grain" ] if video_MMAudio_setting != 0: pp_values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ] pp_labels += [ "MMAudio" ] @@ -3206,7 +3248,7 @@ def 
select_video(state, input_file_list, event_data: gr.EventData): else: html = get_default_video_info() visible= len(file_list) > 0 - return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and is_image) + return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) def expand_slist(slist, num_inference_steps ): new_slist= [] inc = len(slist) / num_inference_steps @@ -3674,6 +3716,8 @@ def edit_video( seed, temporal_upsampling, spatial_upsampling, + film_grain_intensity, + film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, @@ -3694,6 +3738,7 @@ def edit_video( if configs == None: configs = { "type" : get_model_record("Post Processing") } has_already_audio = False + audio_tracks = [] if MMAudio_setting == 0: audio_tracks = extract_audio_tracks(video_source) has_already_audio = len(audio_tracks) > 0 @@ -3711,8 +3756,8 @@ def edit_video( frames_count = min(frames_count, 1000) sample = None - if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0: - send_cmd("progress", [0, get_latest_status(state,"Upsampling")]) + if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 or film_grain_intensity > 0: + send_cmd("progress", [0, get_latest_status(state,"Upsampling" if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 else "Adding Film Grain" )]) sample = get_resampled_video(video_source, 0, max_source_video_frames, fps) sample = sample.float().div_(127.5).sub_(1.).permute(-1,0,1,2) frames_count = sample.shape[1] @@ -3728,6 +3773,12 @@ def edit_video( sample = perform_spatial_upsampling(sample, spatial_upsampling ) configs["spatial_upsampling"] = spatial_upsampling + if film_grain_intensity > 0: + from postprocessing.film_grain import add_film_grain + sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation) + configs["film_grain_intensity"] = film_grain_intensity + configs["film_grain_saturation"] = film_grain_saturation + any_mmaudio = MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and frames_count >=output_fps if any_mmaudio: download_mmaudio() @@ -3834,16 +3885,19 @@ def generate_video( image_refs, frames_positions, video_guide, + image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting, video_mask, + image_mask, control_net_weight, control_net_weight2, mask_expand, audio_guide, audio_guide2, audio_prompt_type, + speakers_locations, sliding_window_size, sliding_window_overlap, sliding_window_overlap_noise, @@ -3851,6 +3905,8 @@ def generate_video( remove_background_images_ref, temporal_upsampling, spatial_upsampling, + film_grain_intensity, + film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, @@ -3871,11 +3927,17 @@ def generate_video( model_filename, mode, ): + + def remove_temp_filenames(temp_filenames_list): + for temp_filename in temp_filenames_list: + if temp_filename!= None and os.path.isfile(temp_filename): + os.remove(temp_filename) + global wan_model, offloadobj, reload_needed, save_path gen = get_gen_info(state) torch.set_grad_enabled(False) - if mode == "edit": - edit_video(send_cmd, state, video_source, seed, temporal_upsampling, spatial_upsampling, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation) + if mode == "edit": + edit_video(send_cmd, state, video_source, seed, temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation, 
MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation) return with lock: file_list = gen["file_list"] @@ -3884,6 +3946,23 @@ def generate_video( model_def = get_model_def(model_type) is_image = model_def.get("image_outputs", False) + if is_image: + batch_size = video_length + video_length = 1 + else: + batch_size = 1 + temp_filenames_list = [] + + if image_guide is not None and isinstance(image_guide, Image.Image): + video_guide = convert_image_to_video(image_guide) + temp_filenames_list.append(video_guide) + image_guide = None + + if image_mask is not None and isinstance(image_mask, Image.Image): + video_mask = convert_image_to_video(image_mask) + temp_filenames_list.append(video_mask) + image_mask = None + fit_canvas = server_config.get("fit_canvas", 0) @@ -3926,7 +4005,6 @@ def generate_video( trans = get_transformer_model(wan_model) audio_sampling_rate = 16000 - temp_filename = None base_model_type = get_base_model_type(model_type) prompts = prompt.split("\n") @@ -4012,6 +4090,11 @@ def generate_video( multitalk = base_model_type in ["multitalk", "vace_multitalk_14B"] flux_dev_kontext = base_model_type in ["flux_dev_kontext"] + if "B" in audio_prompt_type or "X" in audio_prompt_type: + from wan.multitalk.multitalk import parse_speakers_locations + speakers_bboxes, error = parse_speakers_locations(speakers_locations) + else: + speakers_bboxes = None if "L" in image_prompt_type: if len(file_list)>0: video_source = file_list[-1] @@ -4268,7 +4351,7 @@ def generate_video( window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - refresh_preview = {} + refresh_preview = {"image_guide" : None, "image_mask" : None} if fantasy: window_latent_start_frame = (window_start_frame ) // latent_size window_latent_size= (current_video_length - 1) // latent_size + 1 @@ -4426,7 +4509,8 @@ def generate_video( input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video, denoising_strength=denoising_strength, target_camera= target_camera, - frame_num=current_video_length if is_image else (current_video_length // latent_size)* latent_size + 1, + frame_num= (current_video_length // latent_size)* latent_size + 1, + batch_size = batch_size, height = height, width = width, fit_into_canvas = fit_canvas == 1, @@ -4469,11 +4553,13 @@ def generate_video( NAG_scale = NAG_scale, NAG_tau = NAG_tau, NAG_alpha = NAG_alpha, + speakers_bboxes =speakers_bboxes, offloadobj = offloadobj, ) except Exception as e: - if temp_filename!= None and os.path.isfile(temp_filename): - os.remove(temp_filename) + if len(control_audio_tracks) > 0: + cleanup_temp_audio_files(control_audio_tracks) + remove_temp_filenames(temp_filenames_list) offloadobj.unload_all() offload.unload_loras_from_model(trans) # if compile: @@ -4569,7 +4655,9 @@ def generate_video( if len(spatial_upsampling) > 0: sample = perform_spatial_upsampling(sample, spatial_upsampling ) - + if film_grain_intensity> 0: + from postprocessing.film_grain import add_film_grain + sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation) if sliding_window : if frames_already_processed == None: frames_already_processed = sample @@ -4675,8 +4763,8 @@ def generate_video( offload.unload_loras_from_model(trans) if len(control_audio_tracks) > 0: 
cleanup_temp_audio_files(control_audio_tracks) - if temp_filename!= None and os.path.isfile(temp_filename): - os.remove(temp_filename) + + remove_temp_filenames(temp_filenames_list) def prepare_generate_video(state): @@ -5529,7 +5617,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if "lset_name" in inputs: inputs.pop("lset_name") - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "audio_guide2"] + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "image_guide", "video_source", "video_mask", "image_mask", "audio_guide", "audio_guide2"] for k in unsaved_params: inputs.pop(k) if model_filename == None: model_filename = state["model_filename"] @@ -5629,28 +5717,36 @@ def video_to_source_video(state, input_file_list, choice): gr.Info("Selected Video was copied to Source Video input") return file_list[choice] -def image_to_ref_image(state, input_file_list, choice, target, target_name): +def image_to_ref_image_add(state, input_file_list, choice, target, target_name): file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() - gr.Info(f"Selected Image was copied to {target_name}") + gr.Info(f"Selected Image was added to {target_name}") if target == None: target =[] target.append( file_list[choice]) return target +def image_to_ref_image_set(state, input_file_list, choice, target, target_name): + file_list, file_settings_list = get_file_list(state, input_file_list) + if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() + gr.Info(f"Selected Image was copied to {target_name}") + return file_list[choice] -def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation): + +def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation): gen = get_gen_info(state) file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) : - return gr.update(), gr.update() + return gr.update(), gr.update(), gr.update() if not file_list[choice].endswith(".mp4"): gr.Info("Post processing is only available with Videos") - return gr.update(), gr.update() + return gr.update(), gr.update(), gr.update() overrides = { "temporal_upsampling":PP_temporal_upsampling, "spatial_upsampling":PP_spatial_upsampling, + "film_grain_intensity": PP_film_grain_intensity, + "film_grain_saturation": PP_film_grain_saturation, "MMAudio_setting" : PP_MMAudio_setting, "MMAudio_prompt" : PP_MMAudio_prompt, "MMAudio_neg_prompt": PP_MMAudio_neg_prompt, @@ -5682,6 +5778,14 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1] + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1] + return 
extension in [".jpeg", ".jpg", ".png", ".bmp", ".tiff"] + def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -5693,10 +5797,15 @@ def add_videos_to_gallery(state, input_file_list, choice, files_to_load): for file_path in files_to_load: file_settings, _ = get_settings_from_file(state, file_path, False, False, False) if file_settings == None: + fps = 0 try: - fps, width, height, frames_count = get_video_info(file_path) + if has_video_file_extension(file_path): + fps, width, height, frames_count = get_video_info(file_path) + elif has_image_file_extension(file_path): + width, height = Image.open(file_path).size + fps = 1 except: - fps = 0 + pass if fps == 0: invalid_files_count += 1 continue @@ -5878,15 +5987,18 @@ def save_inputs( image_refs, frames_positions, video_guide, + image_guide, keep_frames_video_guide, denoising_strength, video_mask, + image_mask, control_net_weight, control_net_weight2, mask_expand, audio_guide, audio_guide2, audio_prompt_type, + speakers_locations, sliding_window_size, sliding_window_overlap, sliding_window_overlap_noise, @@ -5894,6 +6006,8 @@ def save_inputs( remove_background_images_ref, temporal_upsampling, spatial_upsampling, + film_grain_intensity, + film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, @@ -6097,7 +6211,7 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux): def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources): audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB") audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) - return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type) + return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)) def refresh_image_prompt_type(state, image_prompt_type): any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 @@ -6110,19 +6224,26 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_ vace= test_vace_module(state["model_type"]) return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) -def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask): +def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask): video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) visible= "A" in video_prompt_type - return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible ) + model_type = state["model_type"] + model_def = get_model_def(model_type) + image_outputs = model_def.get("image_outputs", False) + return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible ) def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide): video_prompt_type = del_in_sequence(video_prompt_type, "PDSLCMGUV") video_prompt_type = add_to_sequence(video_prompt_type, 
video_prompt_type_video_guide) visible = "V" in video_prompt_type mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type + model_type = state["model_type"] + model_def = get_model_def(model_type) + image_outputs = model_def.get("image_outputs", False) + vace= test_vace_module(state["model_type"]) - return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible) + return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible) # def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide): # video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] @@ -6223,6 +6344,8 @@ def get_resolution_choices(current_resolution_choice): if resolution_choices == None: resolution_choices=[ # 1080p + ("1920x1088 (21:9, 1080p)", "1920x1088"), + ("1088x1920 (9:21, 1080p)", "1088x1920"), ("1920x832 (21:9, 1080p)", "1920x832"), ("832x1920 (9:21, 1080p)", "832x1920"), # 720p @@ -6390,7 +6513,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if vace: image_prompt_type_value= ui_defaults.get("image_prompt_type","") image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value - image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= True , scale= 3) + image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3) image_start = gr.Gallery(visible = False) image_end = gr.Gallery(visible = False) @@ -6480,12 +6603,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) any_control_video = True + any_control_image = image_outputs with gr.Row(): if t2v: video_prompt_type_video_guide = gr.Dropdown( choices=[ ("Use Text Prompt Only", ""), - ("Video to Video guided by Text Prompt", "GUV"), + ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"), ], value=filter_letters(video_prompt_type_value, "GUV"), label="Video to Video", scale = 2, show_label= False, visible= True @@ -6493,8 +6617,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif vace: video_prompt_type_video_guide = gr.Dropdown( choices=[ - ("No Control Video", ""), - ("Keep Control Video Unchanged", "UV"), + ("No Control Image" if image_outputs else "No Control Video", ""), + ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video 
Unchanged", "UV"), ("Transfer Human Motion", "PV"), ("Transfer Depth", "DV"), ("Transfer Shapes", "SV"), @@ -6510,19 +6634,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Transfer Shapes & Flow", "SLV"), ], value=filter_letters(video_prompt_type_value, "PDSLCMGUV"), - label="Control Video Process", scale = 2, visible= True, show_label= True, + label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True, ) elif hunyuan_video_custom_edit: video_prompt_type_video_guide = gr.Dropdown( choices=[ - ("Inpaint Control Video", "MV"), + ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"), ("Transfer Human Motion", "PMV"), ], value=filter_letters(video_prompt_type_value, "PDSLCMUV"), - label="Video to Video", scale = 3, visible= True, show_label= True, + label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True, ) else: any_control_video = False + any_control_image = False video_prompt_type_video_guide = gr.Dropdown(visible= False) # video_prompt_video_guide_trigger = gr.Text(visible=False, value="") @@ -6578,16 +6703,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible = False, label="Start / Reference Images", scale = 2 ) - - video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),) + image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) + video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False) - keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last + keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= (not image_outputs) and "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col: video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Background or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: 
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -6595,8 +6721,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_bottom = gr.Slider(0, 100, value= video_guide_outpainting_list[1], step=5, label="Bottom %", show_reset_button= False) video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False) video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False) - - video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) + any_image_mask = image_outputs and vace + image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None)) + video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar @@ -6630,7 +6757,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB"), ], value= filter_letters(audio_prompt_type_value, "XCPAB"), - label="Voices: if there are multiple People the first is assumed to be to the Left and the second one to the Right", scale = 3, visible = multitalk + label="Voices", scale = 3, visible = multitalk ) else: audio_prompt_type_sources = gr.Dropdown( choices= [""], value = "", visible=False) @@ -6638,6 +6765,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = any_audio_voices_support) as audio_guide_row: audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value ) audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value ) + with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) ) as speakers_locations_row: + speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. 
Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True) advanced_prompt = advanced_ui prompt_vars=[] @@ -6694,7 +6823,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) with gr.Row(): if image_outputs: - video_length = gr.Slider(1, 16, value=ui_defaults.get("video_length", 1), step=1, label="Number of Images to Generate", visible = flux_dev_kontext) + video_length = gr.Slider(1, 16, value=ui_defaults.get("video_length", 1), step=1, label="Number of Images to Generate", visible = True) elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: @@ -6702,7 +6831,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_length = gr.Slider(min_frames, 737 if test_any_sliding_window(base_model_type) else 337, value=ui_defaults.get( "video_length", 81 if get_model_family(base_model_type)=="wan" else 97), - step=frames_step, label=f"Number of frames ({fps} = 1s)", interactive= True) + step=frames_step, label=f"Number of frames ({fps} = 1s)", visible = True, interactive= True) with gr.Row(visible = not lock_inference_steps) as inference_steps_row: num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True) @@ -6790,12 +6919,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) skip_steps_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("skip_steps_start_step_perc",0), step=1, label="Skip Steps starting moment in % of generation") - with gr.Tab("Upsampling"): + with gr.Tab("Post Processing"): with gr.Column(): gr.Markdown("Upsampling - postprocessing that may improve fluidity and the size of the video") - def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , element_class= None, max_height= None): + def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grain_intensity, film_grain_saturation, element_class= None, max_height= None): temporal_upsampling = gr.Dropdown( choices=[ ("Disabled", ""), @@ -6803,7 +6932,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Rife x4 frames/s", "rife4"), ], value=temporal_upsampling, - visible=not image_outputs, + visible=True, scale = 1, label="Temporal Upsampling", elem_classes= element_class @@ -6822,8 +6951,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elem_classes= element_class # max_height = max_height ) - return temporal_upsampling, spatial_upsampling - temporal_upsampling, spatial_upsampling = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", "")) + + with gr.Row(): + film_grain_intensity = gr.Slider(0, 1, value=film_grain_intensity, step=0.01, label="Film Grain Intensity (0 = disabled)") + film_grain_saturation = gr.Slider(0.0, 1, value=film_grain_saturation, step=0.01, label="Film Grain Saturation") + + return temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation + temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", ""), ui_defaults.get("film_grain_intensity", 0), ui_defaults.get("film_grain_saturation", 0.5)) with gr.Tab("MMAudio", visible = 
server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type) and not image_outputs) as mmaudio_tab: with gr.Column(): @@ -7024,18 +7158,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(**default_visibility) as image_buttons_row: video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", visible = any_end_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", size ="sm", visible = any_control_image ) + video_info_to_image_mask_btn = gr.Button("To Mask Image", size ="sm", visible = any_image_mask) video_info_to_reference_image_btn = gr.Button("To Reference Image", size ="sm", visible = any_reference_image) video_info_eject_image_btn = gr.Button("Eject Image", size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): with gr.Column(): - PP_temporal_upsampling, PP_spatial_upsampling = gen_upsampling_dropdowns("", "", element_class ="postprocess") - with gr.Column() as PP_MMAudio_col: + PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess") + with gr.Column(visible = server_config.get("mmaudio_enabled", 0) == 1) as PP_MMAudio_col: PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns( 0, "" , "", None, element_class ="postprocess" ) PP_MMAudio_seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)") PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate") - - video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True) + with gr.Row(): + video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True) + video_info_eject_video2_btn = gr.Button("Eject Video", size ="sm", visible=True) with gr.Tab("Add Videos / Images", id= "video_add"): files_to_load = gr.Files(label= "Files to Load in Gallery", height=120) with gr.Row(): @@ -7092,8 +7229,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, - video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, - NAG_col] # presets_column, + video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, + NAG_col, speakers_locations_row] # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7104,12 +7241,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non last_choice = gr.Number(value =-1, 
interactive= False, visible= False) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) - audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2]) + audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col]) - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand]) + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand]) + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt]) video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) @@ -7119,7 +7256,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, 
prompt_column_wizard_vars, *prompt_vars]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) - output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab] ) + output.select(select_video, [state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab], trigger_mode="multiple") preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) def refresh_status_async(state, progress=gr.Progress()): @@ -7175,13 +7312,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_choice, refresh_form_trigger]) video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] ) - gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] ) + gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_video2_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] ) video_info_to_control_video_btn.click(fn=video_to_control_video, inputs =[state, output, last_choice], outputs = [video_guide] ) video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] ) - video_info_to_start_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] ) - video_info_to_end_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) - video_info_to_reference_image_btn.click(fn=image_to_ref_image, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) - video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) + video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] ) + video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) + video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] ) + video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] ) + video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) + video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, 
PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( @@ -7405,7 +7544,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) return ( state, loras_choices, lset_name, state, - video_guide, video_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col + video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col ) @@ -8220,12 +8359,12 @@ def create_ui(): header = gr.Markdown(generate_header(transformer_type, compile, attention_mode), visible= True) with gr.Row(): ( state, loras_choices, lset_name, state, - video_guide, video_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col + video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col ) = generate_video_tab(model_choice=model_choice, header=header, main = main) with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: - matanyone_app.display(main_tabs, tab_state, model_choice, video_guide, video_mask, image_refs) + matanyone_app.display(main_tabs, tab_state, model_choice, video_guide, image_guide, video_mask, image_mask, image_refs) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) From afbf94e44f65f9b4e1370b62eada2977d791ed51 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Tue, 15 Jul 2025 23:43:02 +0200 Subject: [PATCH 4/9] fix ? 
--- wan/any2video.py | 2 +- wan/diffusion_forcing.py | 2 +- wan/utils/utils.py | 8 +++++++- wgp.py | 5 +++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/wan/any2video.py b/wan/any2video.py index ba81250..67b1564 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -683,7 +683,7 @@ class WanAny2V: # init denoising updated_num_steps= len(timesteps) if callback != None: - from wgp import update_loras_slists + from wan.utils.utils import update_loras_slists update_loras_slists(self.model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index 9b5918a..85343fd 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -19,7 +19,7 @@ from wan.utils.utils import calculate_new_dimensions from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wgp import update_loras_slists +from wan.utils.utils import update_loras_slists class DTT2V: diff --git a/wan/utils/utils.py b/wan/utils/utils.py index fe0b6bf..9e61fcb 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -33,7 +33,13 @@ def seed_everything(seed: int): torch.cuda.manual_seed(seed) if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) - + +def update_loras_slists(trans, slists, num_inference_steps ): + from mmgp import offload + slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] + nos = [str(l) for l in range(len(slists))] + offload.activate_loras(trans, nos, slists ) + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math diff --git a/wgp.py b/wgp.py index 8e5335b..7888417 100644 --- a/wgp.py +++ b/wgp.py @@ -2107,8 +2107,9 @@ for file_path in models_def_paths: settings = json_def existing_model_def = models_def.get(model_type, None) if existing_model_def is not None: - existing_settings = models_def["settings"] - existing_settings.update(settings) + existing_settings = models_def.get("settings", None) + if existing_settings != None: + existing_settings.update(settings) existing_model_def.update(model_def) else: models_def[model_type] = model_def From bda410f367b452a5d23f34206c34e5af29d5d7d3 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 00:47:23 +0200 Subject: [PATCH 5/9] added better missing model error handling --- wan/utils/utils.py | 9 +++++++++ wgp.py | 31 ++++++++++++++++++------------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 9e61fcb..cbd34e9 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -34,6 +34,15 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def expand_slist(slist, num_inference_steps ): + new_slist= [] + inc = len(slist) / num_inference_steps + pos = 0 + for i in range(num_inference_steps): + new_slist.append(slist[ int(pos)]) + pos += inc + return new_slist + def update_loras_slists(trans, slists, num_inference_steps ): from mmgp import offload slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] diff --git a/wgp.py b/wgp.py index 7888417..ab12f50 100644 --- a/wgp.py +++ b/wgp.py @@ -16,6 +16,7 @@ import json import wan from wan.utils import notification_sound from wan.configs import MAX_AREA_CONFIGS, 
WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS +from wan.utils.utils import expand_slist, update_loras_slists from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files @@ -160,6 +161,7 @@ def process_prompt_and_add_tasks(state, model_choice): model_filename = state["model_filename"] model_type = state["model_type"] inputs = get_model_settings(state, model_type) + if model_choice != model_type or inputs ==None: raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page") @@ -1740,6 +1742,8 @@ def get_model_type(model_filename): def get_model_family(model_type): model_type = get_base_model_type(model_type) + if model_type == None: + return "unknown" if "hunyuan" in model_type : return "hunyuan" elif "ltxv" in model_type: @@ -1799,7 +1803,8 @@ def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): def get_model_name(model_type, description_container = [""]): model_def = get_model_def(model_type) - if model_def == None: raise Exception(f"Unknown model {model_type}") + if model_def == None: + return f"Unknown model {model_type}" model_name = model_def["name"] description = model_def["description"] description_container[0] = description @@ -1832,7 +1837,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_m if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") else: model_def = models_def.get(model_type, None) - if model_def == None: raise Exception(f"Unknown model type {model_type}") + if model_def == None: return None URLs = model_def["URLs"] if isinstance(URLs, str): if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") @@ -2120,11 +2125,18 @@ model_types = models_def.keys() displayed_model_types= [] for model_type in model_types: model_def = get_model_def(model_type) - if not (model_def != None and model_def.get("visible", True) == False): + if not model_def is None and model_def.get("visible", True): displayed_model_types.append(model_type) transformer_types = server_config.get("transformer_types", []) +new_transformer_types = [] +for model_type in transformer_types: + if get_model_def(model_type) == None: + print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json") + else: + new_transformer_types.append(model_type) +transformer_types = new_transformer_types transformer_type = server_config.get("last_model_type", None) advanced = server_config.get("last_advanced_choice", False) if args.advanced: advanced = True @@ -2729,7 +2741,7 @@ def generate_header(model_type, compile, attention_mode): description_container = [""] get_model_name(model_type, description_container) - model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) + model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or "" description = description_container[0] header = "
" + description + "
" @@ -3250,14 +3262,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): html = get_default_video_info() visible= len(file_list) > 0 return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) -def expand_slist(slist, num_inference_steps ): - new_slist= [] - inc = len(slist) / num_inference_steps - pos = 0 - for i in range(num_inference_steps): - new_slist.append(slist[ int(pos)]) - pos += inc - return new_slist + def convert_image(image): from PIL import ImageOps @@ -7871,7 +7876,7 @@ def generate_info_tab(): def get_sorted_dropdown(dropdown_types): - families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3 } + families_order = {"wan":0, "ltxv":1, "hunyuan":2, "flux": 3, "unknown": 100 } dropdown_classes = [ families_order[get_model_family(type)] for type in dropdown_types] dropdown_names = [ get_model_name(type) for type in dropdown_types] From 1e2d74ae7d5bf1759ae39d037952ba4c23cb85bb Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 00:49:16 +0200 Subject: [PATCH 6/9] added better missing model error handling --- wgp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wgp.py b/wgp.py index ab12f50..68bfc74 100644 --- a/wgp.py +++ b/wgp.py @@ -2132,8 +2132,8 @@ for model_type in model_types: transformer_types = server_config.get("transformer_types", []) new_transformer_types = [] for model_type in transformer_types: - if get_model_def(model_type) == None: - print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model in wgp_config.json") + if get_model_def(model_type) == None: + print(f"Model '{model_type}' is missing. Either install it in the finetune folder or remove this model from ley 'transformer_types' in wgp_config.json") else: new_transformer_types.append(model_type) transformer_types = new_transformer_types From 49aaa12689ecbd9276c5bae549e08196604e8861 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 01:15:28 +0200 Subject: [PATCH 7/9] restored neg prompt --- wgp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index 68bfc74..bdb7327 100644 --- a/wgp.py +++ b/wgp.py @@ -6487,6 +6487,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non state = gr.State(state_dict) trigger_refresh_input_type = gr.Text(interactive= False, visible= False) t2v = base_model_type in ["t2v"] + t2v_1_3B = base_model_type in ["t2v_1.3B"] flf2v = base_model_type == "flf2v_720p" diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename @@ -6868,7 +6869,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = vace): control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) - negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = vace or t2v or test_class_i2v(model_type) ) + negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v) ) with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) 
as NAG_col: gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") with gr.Row(): From a356c6af4b72c5ed12b5ff7be2b27d97d059037d Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 18:09:06 +0200 Subject: [PATCH 8/9] fixed causvid scheduler --- flux/flux_main.py | 9 +++------ flux/sampling.py | 4 ++-- preprocessing/depth_anything_v2/layers/attention.py | 8 +------- preprocessing/depth_anything_v2/layers/block.py | 9 +-------- wan/utils/basic_flowmatch.py | 2 +- wan/utils/utils.py | 4 ++-- 6 files changed, 10 insertions(+), 26 deletions(-) diff --git a/flux/flux_main.py b/flux/flux_main.py index b782cc9..202eb44 100644 --- a/flux/flux_main.py +++ b/flux/flux_main.py @@ -72,10 +72,7 @@ class model_factory: if self._interrupt: return None - rng = torch.Generator(device="cuda") - if seed is None: - seed = rng.seed() - + device="cuda" if input_ref_images != None and len(input_ref_images) > 0: image_ref = input_ref_images[0] w, h = image_ref.size @@ -91,7 +88,7 @@ class model_factory: target_height=height, bs=batch_size, seed=seed, - device="cuda", + device=device, ) inp.pop("img_cond_orig") @@ -103,7 +100,7 @@ class model_factory: if x==None: return None # decode latents to pixel space x = unpack_latent(x) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) x = x.clamp(-1, 1) diff --git a/flux/sampling.py b/flux/sampling.py index 7581dea..5c137f1 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -30,8 +30,8 @@ def get_noise( 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype=dtype, - generator=torch.Generator(device="cuda").manual_seed(seed), - ).to(device) + generator=torch.Generator(device=device).manual_seed(seed), + ) def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: diff --git a/preprocessing/depth_anything_v2/layers/attention.py b/preprocessing/depth_anything_v2/layers/attention.py index f1cacb1..5a35c06 100644 --- a/preprocessing/depth_anything_v2/layers/attention.py +++ b/preprocessing/depth_anything_v2/layers/attention.py @@ -15,13 +15,7 @@ from torch import nn logger = logging.getLogger("dinov2") -try: - from xformers.ops import memory_efficient_attention, unbind, fmha - - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False +XFORMERS_AVAILABLE = False class Attention(nn.Module): diff --git a/preprocessing/depth_anything_v2/layers/block.py b/preprocessing/depth_anything_v2/layers/block.py index a711a1f..8de1d57 100644 --- a/preprocessing/depth_anything_v2/layers/block.py +++ b/preprocessing/depth_anything_v2/layers/block.py @@ -23,14 +23,7 @@ from .mlp import Mlp logger = logging.getLogger("dinov2") -try: - from xformers.ops import fmha - from xformers.ops import scaled_index_add, index_select_cat - - XFORMERS_AVAILABLE = True -except ImportError: - # logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False +XFORMERS_AVAILABLE = False class Block(nn.Module): diff --git a/wan/utils/basic_flowmatch.py b/wan/utils/basic_flowmatch.py index 591510b..ceb4657 100644 --- a/wan/utils/basic_flowmatch.py +++ b/wan/utils/basic_flowmatch.py @@ -53,7 +53,7 @@ class FlowMatchScheduler(): else: sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) prev_sample = sample + model_output * (sigma_ - sigma) - return prev_sample + return [prev_sample] def add_noise(self, 
original_samples, noise, timestep): """ diff --git a/wan/utils/utils.py b/wan/utils/utils.py index cbd34e9..53f3b73 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -493,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False): except ffmpeg.Error as e: print(f"FFmpeg error during audio extraction: {e}") - return [] + return 0 if query_only else [] except Exception as e: print(f"Error during audio extraction: {e}") - return [] + return 0 if query_only else [] def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False): """ From 0be33acb57ac525a824a7318598e74b6b04c9e06 Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:22:36 +0200 Subject: [PATCH 9/9] Fixed multitalk crash --- wan/multitalk/multitalk_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py index 4722dae..8462390 100644 --- a/wan/multitalk/multitalk_utils.py +++ b/wan/multitalk/multitalk_utils.py @@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli _, seq_lens, heads, _ = visual_q.shape class_num, _ = ref_target_masks.shape - x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype) + x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) split_chunk = heads // split_num @@ -350,4 +350,4 @@ def adaptive_projected_guidance( diff = diff * scale_factor diff_parallel, diff_orthogonal = project(diff, pred_cond) normalized_update = diff_orthogonal + eta * diff_parallel - return normalized_update \ No newline at end of file + return normalized_update