diff --git a/README.md b/README.md index df859e6..52d7c89 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,13 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### August 21 2025: WanGP v8.0 - the killer of seven + +- Qwen Image Edit: Flux Kontext challenger (prompt-driven image editing). You should use it at high res (1080p) if you want to preserve the identity of the original people / objects. It works with the Qwen Lightning 4 steps Lora. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions +- On-demand Prompt Enhancer (needs to be enabled in the Configuration Tab) that you can use to enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt. +- Choice of a non-censored Prompt Enhancer. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work +- Memory Profile customizable per model: useful, for instance, to set Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case generation will be much faster, because with image generators (contrary to video generators) a lot of time is otherwise wasted in offloading +- Expert Guidance Mode: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow, which have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8 steps process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1. ### August 12 2025: WanGP v7.7777 - Lucky Day(s) diff --git a/defaults/hunyuan.json b/defaults/hunyuan.json index b02a7ea..a6ba832 100644 --- a/defaults/hunyuan.json +++ b/defaults/hunyuan.json @@ -5,7 +5,7 @@ "architecture" : "hunyuan", "description": "Probably the best text 2 video model available.", "URLs": [ - "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors.safetensors", + "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/hunyuan_video_720_quanto_int8.safetensors" ] } diff --git a/defaults/qwen_image_20B.json b/defaults/qwen_image_20B.json index 691afee..8e3fba5 100644 --- a/defaults/qwen_image_20B.json +++ b/defaults/qwen_image_20B.json @@ -2,12 +2,12 @@ "model": { "name": "Qwen Image 20B", "architecture": "qwen_image_20B", - "description": "Qwen Image is generative model that will very high quality images. It is one of the few models capable to generate in the image very long texts.", + "description": "Qwen Image is a generative model that generates very high quality images. 
It is one of the few models capable of generating very long texts inside the image.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_20B_quanto_bf16_int8.safetensors" ], - "resolutions": [ ["1328x1328 (1:1)", "1328x1328"], + "xresolutions": [ ["1328x1328 (1:1)", "1328x1328"], ["1664x928 (16:9)", "1664x928"], ["928x1664 (9:16)", "928x1664"], ["1472x1140 (4:3)", "1472x1140"], @@ -16,6 +16,6 @@ "image_outputs": true }, "prompt": "draw a hat", - "resolution": "1280x720", + "resolution": "1920x1088", "batch_size": 1 } \ No newline at end of file diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json new file mode 100644 index 0000000..183d418 --- /dev/null +++ b/defaults/qwen_image_edit_20B.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Qwen Image Edit 20B", + "architecture": "qwen_image_edit_20B", + "description": "Qwen Image Edit is a generative model that generates very high quality images. It can be used to edit a Subject or combine multiple Subjects. It is one of the few models capable of generating very long texts inside the image.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors" + ], + "attention": { + "<89": "sdpa" + }, + "reference_image": true, + "image_outputs": true + }, + "prompt": "draw a hat", + "resolution": "1280x720", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/vace_14B_lightning_3p_2_2.json b/defaults/vace_14B_lightning_3p_2_2.json new file mode 100644 index 0000000..cfa6ce5 --- /dev/null +++ b/defaults/vace_14B_lightning_3p_2_2.json @@ -0,0 +1,29 @@ +{ + "model": { + "name": "Wan2.2 Vace Lightning 3 Phases 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "This finetune uses the Lightning 4 steps Loras Accelerators for Wan 2.2 but extends them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is to reduce the slow motion effect of these Loras Accelerators.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": ["0;1;0", "0;0;1"], + "lock_guidance_phases": true, + "group": "wan2_2" + }, + "num_inference_steps": 8, + "guidance_phases": 3, + "guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold": 965, + "switch_threshold2": 800, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler" +} \ No newline at end of file diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md index 32bc7c6..7f9dc3f 100644 --- a/docs/FINETUNES.md +++ b/docs/FINETUNES.md @@ -69,9 +69,25 @@ For instance if one adds a module *vace_14B* on top of a model with architecture -*visible* : by default assumed to be true. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and hide it. -*image_outputs* : turn any model that generates a video into a model that generates images. 
In fact it will adapt the user interface for image generation and ask the model to generate a video with a single frame. -In order to favor reusability the properties of *URLs*, *modules*, *loras* and *preload_URLs* can contain instead of a list of URLs a single text which corresponds to the id of a finetune or default model to reuse. +In order to favor reusability, the *URLs*, *modules*, *loras* and *preload_URLs* properties can contain, instead of a list of URLs, a single string which corresponds to the id of a finetune or default model to reuse. Instead of: +``` + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors" + ], +``` + You can write: +``` + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", +``` -For example let's say you have defined a *t2v_fusionix.json* file which contains the URLs to download the finetune. In the *vace_fusionix.json* you can write « URLs » : « fusionix » to reuse automatically the URLS already defined in the correspond file. Example of **model** subtree ``` diff --git a/docs/LORAS.md b/docs/LORAS.md index 73a88ac..89b2e59 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -7,18 +7,21 @@ Loras (Low-Rank Adaptations) allow you to customize video generation models by a Loras are organized in different folders based on the model they're designed for: ### Wan Text-to-Video Models -- `loras/` - General t2v loras +- `loras/` - General t2v loras for Wan 2.1 (t2v only) and for all Wan 2.2 models +Optional sub-folders: - `loras/1.3B/` - Loras specifically for 1.3B models +- `loras/5B/` - Loras specifically for 5B models - `loras/14B/` - Loras specifically for 14B models ### Wan Image-to-Video Models -- `loras_i2v/` - Image-to-video loras +- `loras_i2v/` - Image-to-video loras for Wan 2.1 ### Other Models - `loras_hunyuan/` - Hunyuan Video t2v loras - `loras_hunyuan_i2v/` - Hunyuan Video i2v loras - `loras_ltxv/` - LTX Video loras - `loras_flux/` - Flux loras +- `loras_qwen/` - Qwen loras ## Custom Lora Directory @@ -40,7 +43,9 @@ python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to 2. Launch WanGP 3. In the Advanced Tab, select the "Loras" section 4. Check the loras you want to activate -5. Set multipliers for each lora (default is 1.0) +5. Set multipliers for each lora (default is 1.0 if no multiplier is specified) + +If you add loras to the loras folder after WanGP has been launched, click the *Refresh* button at the top so that they become selectable. 
### Lora Multipliers @@ -53,7 +58,7 @@ Multipliers control the strength of each lora's effect: - First lora: 1.2 strength - Second lora: 0.8 strength -#### Time-based Multipliers +#### Time-based and Phase-based Multipliers For dynamic effects over generation steps, use comma-separated values: ``` 0.9,0.8,0.7 ``` @@ -75,7 +80,7 @@ Also with Wan 2.2, if you have two loras and you want the first one to be applie 1;0 0;1 ``` -As usual, you can use any float for of multiplier and have a multiplier varries throughout one phase for one Lora: +As usual, you can use any float for a multiplier and have a multiplier vary throughout one phase for one Lora: ``` 0.9,0.8;1.2,1.1,1 ``` @@ -87,7 +92,31 @@ Here is another example for two loras: 0.5;0,0.7 ``` -Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a json list) +If one or several of your Lora multipliers are phase based (that is, they contain a ";") and there are also Lora Multipliers that are only time based (no ";" but a ","), the time-only multipliers will ignore the phases. For instance, let's assume we have a 6 steps denoising process in the following example: + +``` +1;0 +0;1 +0.8,0.7,0.5 +``` +Here the first lora will, as expected, only be used with the High Noise model and the second lora only with the Low Noise model. However for the third Lora: for steps 1-2 the multiplier will be 0.8 (regardless of the phase), then for steps 3-4 the multiplier will be 0.7 and finally for steps 5-6 the multiplier will be 0.5. + +You can use phased Lora multipliers even if you have a single model (that is, without any High / Low models) as Lora multiplier phases are aligned with Guidance phases. Let's assume you have defined 3 guidance phases (for instance guidance=3, then guidance=1.5 and finally guidance=1): +``` +0;1;0 +0;0;1 +``` +In that case no lora will be applied during the first phase when guidance is 3. Then the first lora will only be used when guidance is 1.5 and the second lora only when guidance is 1. + +Best of all, you can combine 3 guidance phases with High / Low models. Let's take this practical example with the *Lightning 4/8 steps loras accelerators for Wan 2.2* where we want to increase the motion by adding some guidance at the very beginning (in that case a first phase that lasts only 1 step should be sufficient): +``` +Guidances: 3.5, 1 and 1 +Model transition: Phase 2-3 +Loras Multipliers: 0;1;0 0;0;1 +``` +Here during the first phase with guidance 3.5, the High model will be used but there won't be any lora at all. Then during phase 2 only the High lora will be used (which requires setting the guidance to 1). Finally, in phase 3 WanGP will switch to the Low model and then only the Low lora will be used. A short sketch of how these multiplier strings expand into per-step values is given after the Loras Accelerators introduction below. + +*Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a json list)* ## Lora Presets Lora Presets are combinations of loras with predefined multipliers and prompts. @@ -125,15 +154,22 @@ WanGP supports multiple lora formats: ## Loras Accelerators Most Loras are used to apply a specific style or to alter the content of the output of the generated video. However some Loras have been designed to transform a model into a distilled model which requires fewer steps to generate a video. +Loras accelerators usually require setting the Guidance to 1. Don't forget to do it: otherwise not only will the quality of the generated video be bad, but generation will also be two times slower. 
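The accelerator loras below rely on the phase syntax described in the previous section. To make that syntax concrete, here is a minimal Python sketch (illustration only, not WanGP's actual parser, which lives in `shared/utils/loras_mutipliers.py` further down in this diff) of how one multiplier string maps to per-step values; the 3 + 3 split of the 6 example steps between the High and Low noise phases is an assumption made for the example.
```python
# Minimal sketch of the multiplier syntax described in the previous section.
# Illustration only -- WanGP's real parser is in shared/utils/loras_mutipliers.py.
def expand_multiplier(spec: str, steps_per_phase: list[int]) -> list[float]:
    """Return one multiplier value per denoising step for a single lora."""
    total_steps = sum(steps_per_phase)
    phases = spec.split(";")
    if len(phases) == 1:
        # Time-based only: spread the comma-separated values over ALL steps, ignoring phases.
        values = [float(v) for v in phases[0].split(",")]
        return [values[i * len(values) // total_steps] for i in range(total_steps)]
    # Phase-based: one part per guidance phase; commas spread values inside each phase.
    out: list[float] = []
    for part, phase_steps in zip(phases, steps_per_phase):
        values = [float(v) for v in part.split(",")]
        out += [values[i * len(values) // phase_steps] for i in range(phase_steps)]
    return out

# The 6-step example above, assuming 2 phases (High noise then Low noise) of 3 steps each:
print(expand_multiplier("1;0", [3, 3]))          # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(expand_multiplier("0;1", [3, 3]))          # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
print(expand_multiplier("0.8,0.7,0.5", [3, 3]))  # [0.8, 0.8, 0.7, 0.7, 0.5, 0.5]
```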
-You will find most *Loras Accelerators* here: +You will find most *Loras Accelerators* below: +- Wan 2.1 https://huggingface.co/DeepBeepMeep/Wan2.1/tree/main/loras_accelerators +- Wan 2.2 +https://huggingface.co/DeepBeepMeep/Wan2.2/tree/main/loras_accelerators +- Qwen: +https://huggingface.co/DeepBeepMeep/Qwen_image/tree/main/loras_accelerators + ### Setup Instructions 1. Download the Lora 2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it is an i2v lora -## FusioniX (or FusionX) Lora +## FusioniX (or FusionX) Lora for Wan 2.1 / Wan 2.2 If you need just one Lora accelerator use this one. It is a combination of multiple Loras accelerators (including Causvid below) and style loras. It will not only accelerate the video generation but it will also improve the quality. There are two versions of this lora depending on whether you use it for t2v or i2v ### Usage @@ -148,8 +184,8 @@ If you need just one Lora accelerator use this one. It is a combination of multi 5. Set generation steps from 8-10 6. Generate! -## Safe-Forcing lightx2v Lora (Video Generation Accelerator) -Safeforcing Lora has been created by Kijai from the Safe-Forcing lightx2v distilled Wan model and can generate videos with only 2 steps and offers also a 2x speed improvement since it doesnt require classifier free guidance. It works on both t2v and i2v models +## Self-Forcing lightx2v Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 +The Self-Forcing Lora has been created by Kijai from the Self-Forcing lightx2v distilled Wan model. It can generate videos with only 2 steps and also offers a 2x speed improvement since it doesn't require classifier free guidance. It works on both t2v and i2v models You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors* ### Usage @@ -165,7 +201,7 @@ You will find it under the name of *Wan21_T2V_14B_lightx2v_cfg_step_distill_lora 6. Generate! -## CausVid Lora (Video Generation Accelerator) +## CausVid Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement. ### Usage @@ -188,11 +224,10 @@ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x spe *Note: Lower steps = lower quality (especially motion)* -## AccVid Lora (Video Generation Accelerator) +## AccVid Lora (Video Generation Accelerator) for Wan 2.1 / Wan 2.2 AccVid is a distilled Wan model that generates videos with a 2x speed improvement since classifier free guidance is no longer needed (that is cfg = 1). - ### Usage 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 13B or Vace 13B) or Wan i2v model 2. Enable Advanced Mode 3. In Advanced Generation Tab: - Set Guidance Scale = 1 - Set Shift Scale = 5 4. The number of steps remains unchanged compared to what you would use with the original model but it will be two times faster since classifier free guidance is not needed +## Lightx2v 4 steps Lora (Video Generation Accelerator) for Wan 2.2 +This lora is in fact composed of two loras, one for the High model and one for the Low model of Wan 2.2. + +You need to select these two loras and set the following Loras multipliers: + +``` +1;0 0;1 (the High lora should only be enabled when the High model is loaded, same for the Low lora) +``` + +Don't forget to set guidance to 1! 
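For the 3-phase variant of this setup (the Expert Guidance Mode mentioned in the README and preconfigured in `defaults/vace_14B_lightning_3p_2_2.json` above), the guidance and model switches are driven by timestep thresholds, as handled by `update_guidance` in the `models/wan/any2video.py` hunk further below. Here is a minimal sketch of that selection logic, assuming the defaults from that finetune (8 steps, thresholds 965 and 800, Low model from phase 3); it is an illustration, not the actual implementation.
```python
# Minimal sketch of 3-phase guidance switching -- illustration only; the real logic
# is update_guidance() in models/wan/any2video.py (see its diff further below).
def phase_for_timestep(t: float, switch_threshold: float, switch_threshold2: float) -> int:
    """Phase 1 while t > switch_threshold, phase 2 until t drops to switch_threshold2, then phase 3."""
    if t > switch_threshold:
        return 1
    if t > switch_threshold2:
        return 2
    return 3

# Assumed mapping from defaults/vace_14B_lightning_3p_2_2.json: a CFG phase first, then the
# two accelerated (guidance = 1) phases; model_switch_phase = 2 is read here as "the Low
# noise model takes over from phase 3 onwards".
guidance = {1: 3.5, 2: 1.0, 3: 1.0}
model = {1: "High noise", 2: "High noise", 3: "Low noise"}
for t in [999, 966, 965, 900, 801, 800, 400]:   # timesteps decrease during denoising
    p = phase_for_timestep(t, switch_threshold=965, switch_threshold2=800)
    print(f"t={t:4} -> phase {p}, guidance {guidance[p]}, {model[p]} model")
```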
+## Qwen Image Lightning 4 steps / Lightning 8 steps +A very powerful lora that you can use to reduce the number of steps from 30 to only 4 or 8! +Just install the lora in the *loras_qwen* folder, select the lora, set Guidance to 1 and set the number of steps to 4 or 8. + + https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors @@ -215,6 +265,7 @@ https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg - Loras are loaded on-demand to save VRAM - Multiple loras can be used simultaneously - Time-based multipliers don't use extra memory +- The order of Loras doesn't matter (as long as the lora multipliers are listed in the matching order, of course!) ## Finding Loras @@ -266,6 +317,7 @@ In the video, a man is presented. The man is in a city and looks at his watch. ## Troubleshooting ### Lora Not Working +0. If it is a lora accelerator, Guidance should be set to 1 1. Check if lora is compatible with your model size (1.3B vs 14B) 2. Verify lora format is supported 3. Try different multiplier values @@ -287,12 +339,13 @@ In the video, a man is presented. The man is in a city and looks at his watch. ```bash # Lora-related command line options ---lora-dir path # Path to t2v loras directory +--lora-dir path # Path to t2v loras directory --lora-dir-i2v path # Path to i2v loras directory --lora-dir-hunyuan path # Path to Hunyuan t2v loras --lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras --lora-dir-ltxv path # Path to LTX Video loras --lora-dir-flux path # Path to Flux loras +--lora-dir-qwen path # Path to Qwen loras --lora-preset preset # Load preset on startup --check-loras # Filter incompatible loras ``` \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 9bdc0cf..8b59027 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -15,9 +15,7 @@ class family_handler(): "image_outputs" : True, "no_negative_prompt" : True, } - if flux_schnell: - model_def_output["no_guidance"] = True - else: + if not flux_schnell: model_def_output["embedded_guidance"] = True diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index dc9ba96..79e3083 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -42,7 +42,11 @@ class family_handler(): extra_model_def["frames_minimum"] = 5 extra_model_def["frames_steps"] = 4 extra_model_def["sliding_window"] = False - extra_model_def["embedded_guidance"] = base_model_type in ["hunyuan", "hunyuan_i2v"] + if base_model_type in ["hunyuan", "hunyuan_i2v"]: + extra_model_def["embedded_guidance"] = True + else: + extra_model_def["guidance_max_phases"] = 1 + extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"] extra_model_def["tea_cache"] = True extra_model_def["mag_cache"] = True diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index cfdd069..d35bcd4 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -12,9 +12,7 @@ class family_handler(): def query_model_def(base_model_type, model_def): LTXV_config = model_def.get("LTXV_config", "") distilled= "distilled" in LTXV_config - extra_model_def = { - "no_guidance": True, - } + extra_model_def = {} if distilled: extra_model_def.update({ "lock_inference_steps": True, diff --git a/models/qwen/pipeline_qwenimage.py index ad39ae5..f9bc871 100644 --- 
a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -27,9 +27,30 @@ from diffusers.utils.torch_utils import randn_tensor from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image XLA_AVAILABLE = False +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -122,6 +143,18 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") class QwenImagePipeline(): #DiffusionPipeline r""" @@ -151,20 +184,27 @@ class QwenImagePipeline(): #DiffusionPipeline text_encoder, tokenizer, transformer, + processor, ): self.vae=vae self.text_encoder=text_encoder self.tokenizer=tokenizer self.transformer=transformer + self.processor = processor + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 1024 - self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" - self.prompt_template_encode_start_idx = 34 + if processor is not None: + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + else: + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 self.default_sample_size = 128 def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): @@ -178,6 +218,7 @@ class QwenImagePipeline(): #DiffusionPipeline def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -189,16 +230,36 @@ class QwenImagePipeline(): #DiffusionPipeline template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] - txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" - ).to(device) - encoder_hidden_states = self.text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask, - output_hidden_states=True, - ) - hidden_states = encoder_hidden_states.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + + if self.processor is not None and image is not None: + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + else: + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) @@ -216,6 +277,7 @@ class QwenImagePipeline(): #DiffusionPipeline def encode_prompt( self, prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, @@ -227,6 +289,8 @@ class QwenImagePipeline(): #DiffusionPipeline Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): @@ -241,7 +305,7 @@ class QwenImagePipeline(): #DiffusionPipeline batch_size = len(prompt) if prompt_embeds is None else 
prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -251,6 +315,7 @@ class QwenImagePipeline(): #DiffusionPipeline return prompt_embeds, prompt_embeds_mask + def check_inputs( self, prompt, @@ -344,6 +409,29 @@ class QwenImagePipeline(): #DiffusionPipeline return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -375,6 +463,7 @@ class QwenImagePipeline(): #DiffusionPipeline def prepare_latents( self, + image, batch_size, num_channels_latents, height, @@ -391,22 +480,41 @@ class QwenImagePipeline(): #DiffusionPipeline shape = (batch_size, 1, num_channels_latents, height, width) - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids + return latents, image_latents @property def guidance_scale(self): @@ -453,6 +561,7 @@ class QwenImagePipeline(): #DiffusionPipeline callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + image = None, callback=None, pipeline=None, loras_slists=None, @@ -540,6 +649,10 @@ class QwenImagePipeline(): #DiffusionPipeline height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -567,13 +680,30 @@ class QwenImagePipeline(): #DiffusionPipeline else: batch_size = prompt_embeds.shape[0] device = "cuda" - # device = self._execution_device + + prompt_image = None + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(image) + aspect_ratio = image_width / image_height + if True : + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + # image = self.image_processor.resize(image, image_height, image_width) + image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) + prompt_image = image + image = self.image_processor.preprocess(image, image_height, image_width) + image = image.unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -583,6 +713,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -597,7 +728,8 @@ class QwenImagePipeline(): #DiffusionPipeline # 4. 
Prepare latent variables num_channels_latents = self.transformer.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + latents, image_latents = self.prepare_latents( + image, batch_size * num_images_per_prompt, num_channels_latents, height, @@ -607,7 +739,15 @@ class QwenImagePipeline(): #DiffusionPipeline generator, latents, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + if image is not None: + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + ] + ] * batch_size + else: + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas @@ -639,6 +779,11 @@ class QwenImagePipeline(): #DiffusionPipeline if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) @@ -655,46 +800,54 @@ class QwenImagePipeline(): #DiffusionPipeline # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + if do_true_cfg and joint_pass: noise_pred, neg_noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], img_shapes=img_shapes, - txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist(),negative_prompt_embeds_mask.sum(dim=1).tolist()], + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], attention_kwargs=self.attention_kwargs, **kwargs ) if noise_pred == None: return None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] else: noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, encoder_hidden_states_mask_list=[prompt_embeds_mask], encoder_hidden_states_list=[prompt_embeds], img_shapes=img_shapes, - txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist()], + txt_seq_lens_list=[txt_seq_lens], attention_kwargs=self.attention_kwargs, **kwargs )[0] if noise_pred == None: return None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, encoder_hidden_states_mask_list=[negative_prompt_embeds_mask], encoder_hidden_states_list=[negative_prompt_embeds], img_shapes=img_shapes, - txt_seq_lens_list=[negative_prompt_embeds_mask.sum(dim=1).tolist()], + txt_seq_lens_list=[negative_txt_seq_lens], attention_kwargs=self.attention_kwargs, **kwargs )[0] if neg_noise_pred == None: return None + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] if do_true_cfg: comb_pred = neg_noise_pred + 
true_cfg_scale * (noise_pred - neg_noise_pred) diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 6db6a76..c6004e1 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -13,7 +13,8 @@ class family_handler(): "image_outputs" : True, "sample_solvers":[ ("Default", "default"), - ("Lightning", "lightning")] + ("Lightning", "lightning")], + "guidance_max_phases" : 1, } @@ -21,7 +22,7 @@ class family_handler(): @staticmethod def query_supported_types(): - return ["qwen_image_20B"] + return ["qwen_image_20B", "qwen_image_edit_20B"] @staticmethod def query_family_maps(): @@ -41,7 +42,7 @@ class family_handler(): return { "repoId" : "DeepBeepMeep/Qwen_image", "sourceFolderList" : ["", "Qwen2.5-VL-7B-Instruct"], - "fileList" : [ ["qwen_vae.safetensors", "qwen_vae_config.json"], ["merges.txt", "tokenizer_config.json", "config.json", "vocab.json"] + computeList(text_encoder_filename) ] + "fileList" : [ ["qwen_vae.safetensors", "qwen_vae_config.json"], ["merges.txt", "tokenizer_config.json", "config.json", "vocab.json", "video_preprocessor_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) ] } @staticmethod diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index ccaa758..8f38427 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -12,10 +12,24 @@ from .transformer_qwenimage import QwenImageTransformer2DModel from diffusers.utils import logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer, Qwen2VLProcessor from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from .pipeline_qwenimage import QwenImagePipeline +from PIL import Image +from shared.utils.utils import calculate_new_dimensions + +def stitch_images(img1, img2): + # Resize img2 to match img1's height + width1, height1 = img1.size + width2, height2 = img2.size + new_width2 = int(width2 * height1 / height2) + img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) + + stitched = Image.new('RGB', (width1 + new_width2, height1)) + stitched.paste(img1, (0, 0)) + stitched.paste(img2_resized, (width1, 0)) + return stitched class model_factory(): def __init__( @@ -35,9 +49,16 @@ class model_factory(): transformer_filename = model_filename[0] - tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) + processor = None + tokenizer = None + if base_model_type == "qwen_image_edit_20B": + processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) + else: + tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) - with open("configs/qwen_image_20B.json", 'r', encoding='utf-8') as f: + + base_config_file = "configs/qwen_image_20B.json" + with open(base_config_file, 'r', encoding='utf-8') as f: transformer_config = json.load(f) transformer_config.pop("_diffusers_version") transformer_config.pop("_class_name") @@ -46,9 +67,22 @@ class model_factory(): from accelerate import init_empty_weights with init_empty_weights(): transformer = QwenImageTransformer2DModel(**transformer_config) - offload.load_model_data(transformer, transformer_filename) + source = model_def.get("source", None) + + if source is not None: + 
offload.load_model_data(transformer, source) + else: + offload.load_model_data(transformer, transformer_filename) # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json") + if not source is None: + from wgp import save_model + save_model(transformer, model_type, dtype, None) + + if save_quantized: + from wgp import save_quantized_model + save_quantized_model(transformer, model_type, model_filename[0], dtype, base_config_file) + text_encoder = offload.fast_load_transformers_model(text_encoder_filename, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath= os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json")) # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2) # text_encoder.to(torch.float16) @@ -56,11 +90,12 @@ class model_factory(): vae = offload.fast_load_transformers_model( os.path.join(checkpoint_dir,"qwen_vae.safetensors"), writable_tensors= True , modelClass=AutoencoderKLQwenImage, defaultConfigPath=os.path.join(checkpoint_dir,"qwen_vae_config.json")) - self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer) + self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer, processor) self.vae=vae self.text_encoder=text_encoder self.tokenizer=tokenizer self.transformer=transformer + self.processor = processor def generate( self, @@ -141,19 +176,31 @@ class model_factory(): if n_prompt is None or len(n_prompt) == 0: n_prompt= "text, watermark, copyright, blurry, low resolution" + if input_ref_images is not None: + # image stiching method + stiched = input_ref_images[0] + if "K" in video_prompt_type : + w, h = input_ref_images[0].size + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + + for new_img in input_ref_images[1:]: + stiched = stitch_images(stiched, new_img) + input_ref_images = [stiched] + image = self.pipeline( - prompt=input_prompt, - negative_prompt=n_prompt, - width=width, - height=height, - num_inference_steps=sampling_steps, - num_images_per_prompt = batch_size, - true_cfg_scale=guide_scale, - callback = callback, - pipeline=self, - loras_slists=loras_slists, - joint_pass = joint_pass, - generator=torch.Generator(device="cuda").manual_seed(seed) + prompt=input_prompt, + negative_prompt=n_prompt, + image = input_ref_images, + width=width, + height=height, + num_inference_steps=sampling_steps, + num_images_per_prompt = batch_size, + true_cfg_scale=guide_scale, + callback = callback, + pipeline=self, + loras_slists=loras_slists, + joint_pass = joint_pass, + generator=torch.Generator(device="cuda").manual_seed(seed) ) if image is None: return None return image.transpose(0, 1) diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py index 042032d..8751648 100644 --- a/models/qwen/transformer_qwenimage.py +++ b/models/qwen/transformer_qwenimage.py @@ -26,6 +26,7 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm from shared.attention import pay_attention +import functools def get_timestep_embedding( timesteps: torch.Tensor, @@ -150,8 +151,8 @@ class 
QwenEmbedRope(nn.Module): super().__init__() self.theta = theta self.axes_dim = axes_dim - pos_index = torch.arange(1024) - neg_index = torch.arange(1024).flip(0) * -1 - 1 + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), @@ -170,7 +171,7 @@ class QwenEmbedRope(nn.Module): ) self.rope_cache = {} - # 是否使用 scale rope + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): @@ -194,38 +195,54 @@ class QwenEmbedRope(nn.Module): if isinstance(video_fhw, list): video_fhw = video_fhw[0] - frame, height, width = video_fhw - rope_key = f"{frame}_{height}_{width}" + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] - if rope_key not in self.rope_cache: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) - if self.scale_rope: - freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) - self.rope_cache[rope_key] = freqs.clone().contiguous() - vid_freqs = self.rope_cache[rope_key] - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
+ vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + class QwenDoubleStreamAttnProcessor2_0: """ diff --git a/models/wan/any2video.py b/models/wan/any2video.py index c879701..c2493ab 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -393,7 +393,11 @@ class WanAny2V: sampling_steps=50, guide_scale=5.0, guide2_scale = 5.0, + guide3_scale = 5.0, switch_threshold = 0, + switch2_threshold = 0, + guide_phases= 1 , + model_switch_phase = 1, n_prompt="", seed=-1, callback = None, @@ -427,6 +431,7 @@ class WanAny2V: prefix_frames_count = 0, image_mode = 0, window_no = 0, + set_header_text = None, **bbargs ): @@ -745,16 +750,25 @@ class WanAny2V: # init denoising updated_num_steps= len(timesteps) - if callback != None: - from shared.utils.loras_mutipliers import update_loras_slists - model_switch_step = updated_num_steps - for i, t in enumerate(timesteps): - if t <= switch_threshold: - model_switch_step = i - break - update_loras_slists(self.model, loras_slists, updated_num_steps, model_switch_step= model_switch_step) - if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, model_switch_step= model_switch_step) - callback(-1, None, True, override_num_inference_steps = updated_num_steps) + + denoising_extra = "" + from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps + + phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) + if len(phases_description) > 0: set_header_text(phases_description) + guidance_switch_done = guidance_switch2_done = False + if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}" + def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_done, switch_threshold, trans, phase_no, denoising_extra): + if guide_phases >= phase_no and not guidance_switch_done and t <= switch_threshold: + if model_switch_phase == phase_no-1 and self.model2 is not None: trans = self.model2 + guide_scale, guidance_switch_done = new_guide_scale, True + denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}" + callback(step_no-1, 
denoising_extra = denoising_extra) + return guide_scale, guidance_switch_done, trans, denoising_extra + update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) + if sample_scheduler != None: scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} @@ -766,16 +780,12 @@ class WanAny2V: text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - guidance_switch_done = False # denoising trans = self.model for i, t in enumerate(tqdm(timesteps)): - if not guidance_switch_done and t <= switch_threshold: - guide_scale = guide2_scale - if self.model2 is not None: trans = self.model2 - guidance_switch_done = True - + guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra) + guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, denoising_extra) offload.set_step_no_for_lora(trans, i) timestep = torch.stack([t]) @@ -920,7 +930,7 @@ class WanAny2V: if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) - callback(i, latents_preview[0], False) + callback(i, latents_preview[0], False, denoising_extra =denoising_extra ) latents_preview = None if timestep_injection: diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index bf7c266..bc79e2e 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -24,6 +24,8 @@ class family_handler(): extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True + extra_model_def["guidance_max_phases"] = 1 + return extra_model_def @staticmethod diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index d36efc2..8f9695b 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -767,6 +767,15 @@ class WanModel(ModelMixin, ConfigMixin): if first == None: return sd + new_sd = {} + + # for k,v in sd.items(): + # if k.endswith("modulation.diff"): + # pass + # else: + # new_sd[ k] = v + # sd = new_sd + # if first.startswith("blocks."): # new_sd = {} # for k,v in sd.items(): diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 6d222a2..dc55a15 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -7,6 +7,9 @@ def test_class_i2v(base_model_type): def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] +def test_multitalk(base_model_type): + return base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] + class family_handler(): @staticmethod @@ -79,11 +82,11 @@ class family_handler(): extra_model_def["no_steps_skipping"] = True i2v = test_class_i2v(base_model_type) extra_model_def["i2v_class"] = i2v - extra_model_def["multitalk_class"] = base_model_type in ["multitalk", "vace_multitalk_14B", 
"i2v_2_2_multitalk"] + extra_model_def["multitalk_class"] = test_multitalk(base_model_type) vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"] extra_model_def["vace_class"] = vace_class - if base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]: + if test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 @@ -92,7 +95,7 @@ class family_handler(): else: fps = 16 extra_model_def["fps"] =fps - + multiple_submodels = "URLs2" in model_def if vace_class: frames_minimum, frames_steps = 17, 4 else: @@ -101,12 +104,13 @@ class family_handler(): "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, "sliding_window" : base_model_type in ["multitalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", - "guidance_max_phases" : 2, + "multiple_submodels" : multiple_submodels, + "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or "URLs2" in model_def), + "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), "mag_cache" : True, "sample_solvers":[ ("unipc", "unipc"), @@ -157,8 +161,7 @@ class family_handler(): @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors - if base_model_type == "ti2v_2_2": return None, None - latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan") + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @staticmethod @@ -218,6 +221,10 @@ class family_handler(): if ui_defaults.get("sample_solver", "") == "": ui_defaults["sample_solver"] = "unipc" + if settings_version < 2.24: + if model_def.get("multiple_submodels", False) or ui_defaults.get("switch_threshold", 0) > 0: + ui_defaults["guidance_phases"] = 2 + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ @@ -233,7 +240,6 @@ class family_handler(): ui_defaults.update({ "guidance_scale": 5.0, "flow_shift": 7, # 11 for 720p - "audio_guidance_scale": 4, "sliding_window_discard_last_frames" : 4, "sample_solver" : "euler", "adaptive_switch" : 1, @@ -258,4 +264,10 @@ class family_handler(): "image_prompt_type": "T", }) - \ No newline at end of file + if test_multitalk(base_model_type): + ui_defaults["audio_guidance_scale"] = 4 + + if model_def.get("multiple_submodels", False): + ui_defaults["guidance_phases"] = 2 + + \ No newline at end of file diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py index 5ec1e59..6e865fa 100644 --- a/shared/RGB_factors.py +++ b/shared/RGB_factors.py @@ -1,26 +1,80 @@ -# thanks Comfyui for the rgb factors -def get_rgb_factors(model_family): +# thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py) +def get_rgb_factors(model_family, model_type = None): if model_family == "wan": - latent_channels = 16 - latent_dimensions = 3 - latent_rgb_factors = [ - [-0.1299, -0.1692, 0.2932], - [ 0.0671, 0.0406, 0.0442], - [ 0.3568, 0.2548, 0.1747], - [ 0.0372, 0.2344, 0.1420], - [ 0.0313, 0.0189, -0.0328], - [ 0.0296, -0.0956, -0.0665], - [-0.3477, -0.4059, -0.2925], - [ 0.0166, 0.1902, 0.1975], - [-0.0412, 0.0267, -0.1364], - [-0.1293, 0.0740, 0.1636], - [ 0.0680, 0.3019, 0.1128], - [ 0.0032, 0.0581, 0.0639], - [-0.1251, 0.0927, 0.1699], - [ 0.0060, -0.0633, 
0.0005], - [ 0.3477, 0.2275, 0.2950], - [ 0.1984, 0.0913, 0.1861] - ] + if model_type =="ti2v_2_2": + latent_channels = 48 + latent_dimensions = 3 + latent_rgb_factors = [ + [ 0.0119, 0.0103, 0.0046], + [-0.1062, -0.0504, 0.0165], + [ 0.0140, 0.0409, 0.0491], + [-0.0813, -0.0677, 0.0607], + [ 0.0656, 0.0851, 0.0808], + [ 0.0264, 0.0463, 0.0912], + [ 0.0295, 0.0326, 0.0590], + [-0.0244, -0.0270, 0.0025], + [ 0.0443, -0.0102, 0.0288], + [-0.0465, -0.0090, -0.0205], + [ 0.0359, 0.0236, 0.0082], + [-0.0776, 0.0854, 0.1048], + [ 0.0564, 0.0264, 0.0561], + [ 0.0006, 0.0594, 0.0418], + [-0.0319, -0.0542, -0.0637], + [-0.0268, 0.0024, 0.0260], + [ 0.0539, 0.0265, 0.0358], + [-0.0359, -0.0312, -0.0287], + [-0.0285, -0.1032, -0.1237], + [ 0.1041, 0.0537, 0.0622], + [-0.0086, -0.0374, -0.0051], + [ 0.0390, 0.0670, 0.2863], + [ 0.0069, 0.0144, 0.0082], + [ 0.0006, -0.0167, 0.0079], + [ 0.0313, -0.0574, -0.0232], + [-0.1454, -0.0902, -0.0481], + [ 0.0714, 0.0827, 0.0447], + [-0.0304, -0.0574, -0.0196], + [ 0.0401, 0.0384, 0.0204], + [-0.0758, -0.0297, -0.0014], + [ 0.0568, 0.1307, 0.1372], + [-0.0055, -0.0310, -0.0380], + [ 0.0239, -0.0305, 0.0325], + [-0.0663, -0.0673, -0.0140], + [-0.0416, -0.0047, -0.0023], + [ 0.0166, 0.0112, -0.0093], + [-0.0211, 0.0011, 0.0331], + [ 0.1833, 0.1466, 0.2250], + [-0.0368, 0.0370, 0.0295], + [-0.3441, -0.3543, -0.2008], + [-0.0479, -0.0489, -0.0420], + [-0.0660, -0.0153, 0.0800], + [-0.0101, 0.0068, 0.0156], + [-0.0690, -0.0452, -0.0927], + [-0.0145, 0.0041, 0.0015], + [ 0.0421, 0.0451, 0.0373], + [ 0.0504, -0.0483, -0.0356], + [-0.0837, 0.0168, 0.0055] + ] + else: + latent_channels = 16 + latent_dimensions = 3 + latent_rgb_factors = [ + [-0.1299, -0.1692, 0.2932], + [ 0.0671, 0.0406, 0.0442], + [ 0.3568, 0.2548, 0.1747], + [ 0.0372, 0.2344, 0.1420], + [ 0.0313, 0.0189, -0.0328], + [ 0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [ 0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [ 0.0680, 0.3019, 0.1128], + [ 0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [ 0.0060, -0.0633, 0.0005], + [ 0.3477, 0.2275, 0.2950], + [ 0.1984, 0.0913, 0.1861] + ] latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 6d2acca..58cc9a9 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -8,7 +8,7 @@ def preparse_loras_multipliers(loras_multipliers): loras_multipliers = " ".join(loras_mult_choices_list) return loras_multipliers.split(" ") -def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step ): +def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ): def expand_one(slist, num_inference_steps): if not isinstance(slist, list): slist = [slist] new_slist= [] @@ -23,13 +23,20 @@ def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step ): phase1 = slists_dict["phase1"][mult_no] phase2 = slists_dict["phase2"][mult_no] - if isinstance(phase1, float) and isinstance(phase2, float) and phase1 == phase2: - return phase1 - return expand_one(phase1, model_switch_step) + expand_one(phase2, num_inference_steps - model_switch_step) + phase3 = slists_dict["phase3"][mult_no] + shared = slists_dict["shared"][mult_no] + if shared: + if isinstance(phase1, float): return phase1 + return expand_one(phase1, num_inference_steps) + else: + if isinstance(phase1, float) and isinstance(phase2, float) and isinstance(phase3, float) and phase1 == phase2 
and phase2 == phase3: return phase1 + return expand_one(phase1, model_switch_step) + expand_one(phase2, model_switch_step2 - model_switch_step) + expand_one(phase3, num_inference_steps - model_switch_step2) -def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, max_phases = 2, model_switch_step = None): +def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None): if model_switch_step is None: model_switch_step = num_inference_steps + if model_switch_step2 is None: + model_switch_step2 = num_inference_steps def is_float(element: any) -> bool: if element is None: return False @@ -40,8 +47,11 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me return False loras_list_mult_choices_nums = [] slists_dict = { "model_switch_step": model_switch_step} + slists_dict = { "model_switch_step2": model_switch_step2} slists_dict["phase1"] = phase1 = [1.] * nb_loras slists_dict["phase2"] = phase2 = [1.] * nb_loras + slists_dict["phase3"] = phase3 = [1.] * nb_loras + slists_dict["shared"] = shared = [False] * nb_loras if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0: list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras] @@ -51,41 +61,66 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me mult = mult.strip() phase_mult = mult.split(";") shared_phases = len(phase_mult) <=1 - if len(phase_mult) > max_phases: - return "", "", f"Loras can not be defined for more than {max_phases} Denoising phase{'s' if max_phases>1 else ''} for this model" + if not shared_phases and len(phase_mult) != nb_phases : + return "", "", f"if the ';' syntax is used for one Lora multiplier, the multipliers for its {nb_phases} denoising phases should be specified for this multiplier" for phase_no, mult in enumerate(phase_mult): - if phase_no > 0: current_phase = phase2 + if phase_no == 1: + current_phase = phase2 + elif phase_no == 2: + current_phase = phase3 if "," in mult: multlist = mult.split(",") slist = [] for smult in multlist: if not is_float(smult): - return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid" + return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid in Phase {phase_no+1}" slist.append(float(smult)) else: if not is_float(mult): return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid" slist = float(mult) if shared_phases: - phase1[i] = phase2[i] = slist + phase1[i] = phase2[i] = phase3[i] = slist + shared[i] = True else: current_phase[i] = slist else: - phase1[i] = phase2[i] = float(mult) + phase1[i] = phase2[i] = phase3[i] = float(mult) + shared[i] = True if merge_slist is not None: slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1 slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2 + slists_dict["phase3"] = phase3 = merge_slist["phase3"] + phase3 + slists_dict["shared"] = shared = merge_slist["shared"] + shared - loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step ) for i in range(len(phase1)) ] + loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step, model_switch_step2 ) for i in range(len(phase1)) ] loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ] return 
loras_list_mult_choices_nums, slists_dict, "" -def update_loras_slists(trans, slists_dict, num_inference_steps, model_switch_step = None ): +def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ): from mmgp import offload sz = len(slists_dict["phase1"]) - slists = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step ) for i in range(sz) ] + slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] nos = [str(l) for l in range(sz)] offload.activate_loras(trans, nos, slists ) + + +def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): + model_switch_step = model_switch_step2 = None + for i, t in enumerate(timesteps): + if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i + if guide_phases >=3 and model_switch_step2 is None and t <= switch2_threshold: model_switch_step2 = i + if model_switch_step is None: model_switch_step = total_num_steps + if model_switch_step2 is None: model_switch_step2 = total_num_steps + phases_description = "" + if guide_phases > 1: + phases_description = "Denoising Steps: " + phases_description += f" Phase 1 = None" if model_switch_step == 0 else f" Phase 1 = 1:{ min(model_switch_step,total_num_steps) }" + if model_switch_step < total_num_steps: + phases_description += f", Phase 2 = None" if model_switch_step == model_switch_step2 else f", Phase 2 = {model_switch_step +1}:{ min(model_switch_step2,total_num_steps) }" + if guide_phases > 2 and model_switch_step2 < total_num_steps: + phases_description += f", Phase 3 = {model_switch_step2 +1}:{ total_num_steps}" + return model_switch_step, model_switch_step2, phases_description diff --git a/shared/utils/notification_sound.py b/shared/utils/notification_sound.py index 64ffd8f..26d1966 100644 --- a/shared/utils/notification_sound.py +++ b/shared/utils/notification_sound.py @@ -8,17 +8,22 @@ import sys import threading import time import numpy as np -import os + os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" -def generate_notification_beep(volume=50, sample_rate=44100): +_cached_waveforms = {} +_sample_rate = 44100 +_mixer_initialized = False +_mixer_lock = threading.Lock() + +def _generate_notification_beep(volume=50, sample_rate=_sample_rate): """Generate pleasant C major chord notification sound""" if volume == 0: return np.array([]) - + volume = max(0, min(100, volume)) - - # Volume curve mapping: 25%->50%, 50%->75%, 75%->100%, 100%->105% + + # Volume curve mapping if volume <= 25: volume_mapped = (volume / 25.0) * 0.5 elif volume <= 50: @@ -26,232 +31,191 @@ def generate_notification_beep(volume=50, sample_rate=44100): elif volume <= 75: volume_mapped = 0.75 + ((volume - 50) / 25.0) * 0.25 else: - volume_mapped = 1.0 + ((volume - 75) / 25.0) * 0.05 # Only 5% boost instead of 15% - + volume_mapped = 1.0 + ((volume - 75) / 25.0) * 0.05 + volume = volume_mapped - + # C major chord frequencies - freq_c = 261.63 # C4 - freq_e = 329.63 # E4 - freq_g = 392.00 # G4 - + freq_c, freq_e, freq_g = 261.63, 329.63, 392.00 duration = 0.8 t = np.linspace(0, duration, int(sample_rate * duration), False) - + # Generate chord components - wave_c = np.sin(freq_c * 2 * np.pi * t) * 0.4 - wave_e = np.sin(freq_e * 2 * np.pi * t) * 0.3 - wave_g = np.sin(freq_g * 2 * np.pi * t) * 0.2 - - wave = wave_c + wave_e + wave_g - - # Prevent clipping + wave = ( + np.sin(freq_c * 2 * np.pi * 
t) * 0.4 + + np.sin(freq_e * 2 * np.pi * t) * 0.3 + + np.sin(freq_g * 2 * np.pi * t) * 0.2 + ) + + # Normalize max_amplitude = np.max(np.abs(wave)) if max_amplitude > 0: wave = wave / max_amplitude * 0.8 - + # ADSR envelope def apply_adsr_envelope(wave_data): length = len(wave_data) attack_time = int(0.2 * length) decay_time = int(0.1 * length) release_time = int(0.5 * length) - + envelope = np.ones(length) - + if attack_time > 0: envelope[:attack_time] = np.power(np.linspace(0, 1, attack_time), 3) - + if decay_time > 0: - start_idx = attack_time - end_idx = attack_time + decay_time + start_idx, end_idx = attack_time, attack_time + decay_time envelope[start_idx:end_idx] = np.linspace(1, 0.85, decay_time) - + if release_time > 0: start_idx = length - release_time envelope[start_idx:] = 0.85 * np.exp(-4 * np.linspace(0, 1, release_time)) - + return wave_data * envelope - + wave = apply_adsr_envelope(wave) - + # Simple low-pass filter def simple_lowpass_filter(signal, cutoff_ratio=0.8): window_size = max(3, int(len(signal) * 0.001)) if window_size % 2 == 0: window_size += 1 - + kernel = np.ones(window_size) / window_size - padded = np.pad(signal, window_size//2, mode='edge') - filtered = np.convolve(padded, kernel, mode='same') - return filtered[window_size//2:-window_size//2] - + padded = np.pad(signal, window_size // 2, mode="edge") + filtered = np.convolve(padded, kernel, mode="same") + return filtered[window_size // 2 : -window_size // 2] + wave = simple_lowpass_filter(wave) - - # Add reverb effect + + # Add reverb if len(wave) > sample_rate // 4: delay_samples = int(0.12 * sample_rate) reverb = np.zeros_like(wave) reverb[delay_samples:] = wave[:-delay_samples] * 0.08 wave = wave + reverb - - # Apply volume first, then normalize to prevent clipping - wave = wave * volume * 0.5 - - # Final normalization with safety margin - max_amplitude = np.max(np.abs(wave)) - if max_amplitude > 0.85: # If approaching clipping threshold - wave = wave / max_amplitude * 0.85 # More conservative normalization - - return wave -_mixer_lock = threading.Lock() -def play_audio_with_pygame(audio_data, sample_rate=44100): - """ - Play audio with clean stereo output - sounds like single notification from both speakers - """ + # Apply volume & final normalize + wave = wave * volume * 0.5 + max_amplitude = np.max(np.abs(wave)) + if max_amplitude > 0.85: + wave = wave / max_amplitude * 0.85 + + return wave + +def _get_cached_waveform(volume): + """Return cached waveform for volume""" + if volume not in _cached_waveforms: + _cached_waveforms[volume] = _generate_notification_beep(volume) + return _cached_waveforms[volume] + + +def play_audio_with_pygame(audio_data, sample_rate=_sample_rate): + """Play audio with pygame backend""" + global _mixer_initialized try: import pygame - + with _mixer_lock: - if len(audio_data) == 0: - return False - - # Clean mixer initialization - quit any existing mixer first - if pygame.mixer.get_init() is not None: - pygame.mixer.quit() - time.sleep(0.2) # Longer pause to ensure clean shutdown - - # Initialize fresh mixer - pygame.mixer.pre_init( - frequency=sample_rate, - size=-16, - channels=2, - buffer=512 # Smaller buffer to reduce latency/doubling - ) - pygame.mixer.init() - - # Verify clean initialization + if not _mixer_initialized: + pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=512) + pygame.mixer.init() + _mixer_initialized = True + mixer_info = pygame.mixer.get_init() if mixer_info is None or mixer_info[2] != 2: return False - - # Prepare audio - ensure 
clean conversion + audio_int16 = (audio_data * 32767).astype(np.int16) if len(audio_int16.shape) > 1: audio_int16 = audio_int16.flatten() - - # Create clean stereo with identical channels + stereo_data = np.zeros((len(audio_int16), 2), dtype=np.int16) - stereo_data[:, 0] = audio_int16 # Left channel - stereo_data[:, 1] = audio_int16 # Right channel - - # Create sound and play once + stereo_data[:, 0] = audio_int16 + stereo_data[:, 1] = audio_int16 + sound = pygame.sndarray.make_sound(stereo_data) - - # Ensure only one playback - pygame.mixer.stop() # Stop any previous sounds + pygame.mixer.stop() sound.play() - - # Wait for completion + duration_ms = int(len(audio_data) / sample_rate * 1000) + 50 pygame.time.wait(duration_ms) - + return True - + except ImportError: return False except Exception as e: - print(f"Pygame clean error: {e}") + print(f"Pygame error: {e}") return False - -def play_audio_with_sounddevice(audio_data, sample_rate=44100): + +def play_audio_with_sounddevice(audio_data, sample_rate=_sample_rate): """Play audio using sounddevice backend""" try: import sounddevice as sd sd.play(audio_data, sample_rate) sd.wait() return True - except ImportError: return False except Exception as e: print(f"Sounddevice error: {e}") return False - -def play_audio_with_winsound(audio_data, sample_rate=44100): +def play_audio_with_winsound(audio_data, sample_rate=_sample_rate): """Play audio using winsound backend (Windows only)""" if sys.platform != "win32": return False - try: - import winsound - import wave - import tempfile - import uuid - + import winsound, wave, tempfile, uuid + temp_dir = tempfile.gettempdir() temp_filename = os.path.join(temp_dir, f"notification_{uuid.uuid4().hex}.wav") - + try: - with wave.open(temp_filename, 'w') as wav_file: + with wave.open(temp_filename, "w") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) wav_file.setframerate(sample_rate) - audio_int16 = (audio_data * 32767).astype(np.int16) wav_file.writeframes(audio_int16.tobytes()) - + winsound.PlaySound(temp_filename, winsound.SND_FILENAME) - + finally: - # Clean up temp file - for _ in range(3): - try: - if os.path.exists(temp_filename): - os.unlink(temp_filename) - break - except: - time.sleep(0.1) - + try: + if os.path.exists(temp_filename): + os.unlink(temp_filename) + except: + pass + return True - except ImportError: return False except Exception as e: print(f"Winsound error: {e}") return False - def play_notification_sound(volume=50): """Play notification sound with specified volume""" if volume == 0: return - - audio_data = generate_notification_beep(volume=volume) - + + audio_data = _get_cached_waveform(volume) if len(audio_data) == 0: return - - # Try audio backends in order - audio_backends = [ - play_audio_with_pygame, - play_audio_with_sounddevice, - play_audio_with_winsound, - ] - + + audio_backends = [play_audio_with_pygame, play_audio_with_sounddevice, play_audio_with_winsound] for backend in audio_backends: try: if backend(audio_data): return - except Exception as e: + except Exception: continue - - # Fallback: terminal beep - print(f"All audio backends failed, using terminal beep") - print('\a') + print("All audio backends failed, using terminal beep") + print("\a") def play_notification_async(volume=50): """Play notification sound asynchronously (non-blocking)""" @@ -260,24 +224,12 @@ def play_notification_async(volume=50): play_notification_sound(volume) except Exception as e: print(f"Error playing notification sound: {e}") - - sound_thread = 
threading.Thread(target=play_sound, daemon=True) - sound_thread.start() + threading.Thread(target=play_sound, daemon=True).start() def notify_video_completion(video_path=None, volume=50): """Notify about completed video generation""" play_notification_async(volume) - -if __name__ == "__main__": - print("Testing notification sounds with different volumes...") - print("Auto-detecting available audio backends...") - - volumes = [25, 50, 75, 100] - for vol in volumes: - print(f"Testing volume {vol}%:") - play_notification_sound(vol) - time.sleep(2) - - print("Test completed!") \ No newline at end of file +for vol in (25, 50, 75, 100): + _get_cached_waveform(vol) \ No newline at end of file diff --git a/shared/utils/prompt_parser.py b/shared/utils/prompt_parser.py index faaa1ca..46edec4 100644 --- a/shared/utils/prompt_parser.py +++ b/shared/utils/prompt_parser.py @@ -1,6 +1,6 @@ import re -def process_template(input_text): +def process_template(input_text, keep_comments = False): """ Process a text template with macro instructions and variable substitution. Supports multiple values for variables to generate multiple output versions. @@ -28,9 +28,12 @@ def process_template(input_text): line_number += 1 # Skip empty lines or comments - if not line or line.startswith('#'): + if not line: continue - + + if line.startswith('#') and not keep_comments: + continue + # Handle macro instructions if line.startswith('!'): # Process any accumulated template lines before starting a new macro @@ -106,13 +109,14 @@ def process_template(input_text): # Handle template lines else: - # Check for unknown variables in template line - var_references = re.findall(r'\{([^}]+)\}', line) - for var_ref in var_references: - if var_ref not in current_variables: - error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'" - return "", error_message - + if not line.startswith('#'): + # Check for unknown variables in template line + var_references = re.findall(r'\{([^}]+)\}', line) + for var_ref in var_references: + if var_ref not in current_variables: + error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'" + return "", error_message + # Add to current template lines current_template_lines.append(line) diff --git a/wgp.py b/wgp.py index c306383..c7e8bc2 100644 --- a/wgp.py +++ b/wgp.py @@ -55,8 +55,8 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.10" -WanGP_version = "7.7777" -settings_version = 2.23 +WanGP_version = "8.0" +settings_version = 2.24 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -71,8 +71,8 @@ task_id = 0 vmc_event_handler = matanyone_app.get_vmc_event_handler() unique_id = 0 unique_id_lock = threading.Lock() -offloadobj = None -wan_model = None +gen_lock = threading.Lock() +offloadobj = enhancer_offloadobj = wan_model = None def get_unique_id(): global unique_id @@ -164,7 +164,6 @@ def process_prompt_and_add_tasks(state, model_choice): return state["validate_success"] = 0 - model_filename = state["model_filename"] model_type = state["model_type"] inputs = get_model_settings(state, model_type) @@ -299,13 +298,24 @@ def process_prompt_and_add_tasks(state, model_choice): switch_threshold = inputs["switch_threshold"] loras_multipliers = inputs["loras_multipliers"] activated_loras = inputs["activated_loras"] + guidance_phases= inputs["guidance_phases"] + model_switch_phase = 
inputs["model_switch_phase"] + switch_threshold = inputs["switch_threshold"] + switch_threshold2 = inputs["switch_threshold2"] + if len(loras_multipliers) > 0: - _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, max_phases= 2 if get_model_family(model_type)=="wan" and model_type not in ["sky_df_1.3B", "sky_df_14B"] else 1) + _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) if len(errors) > 0: gr.Info(f"Error parsing Loras Multipliers: {errors}") return - + if guidance_phases == 3: + if switch_threshold < switch_threshold2: + gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). As a reminder, noise will gradually go down from 1000 to 0.") + return + else: + model_switch_phase = 1 + if not any_steps_skipping: skip_steps_cache_type = "" if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") @@ -535,7 +545,8 @@ def process_prompt_and_add_tasks(state, model_choice): "image_prompt_type": image_prompt_type, "video_prompt_type": video_prompt_type, "audio_prompt_type": audio_prompt_type, - "skip_steps_cache_type": skip_steps_cache_type + "skip_steps_cache_type": skip_steps_cache_type, + "model_switch_phase": model_switch_phase, } if inputs["multi_prompts_gen_type"] == 0: @@ -2185,7 +2196,8 @@ if len(args.attention)> 0: else: raise Exception(f"Unknown attention mode '{args.attention}'") -profile = force_profile_no if force_profile_no >=0 else server_config["profile"] +default_profile = force_profile_no if force_profile_no >=0 else server_config["profile"] +loaded_profile = -1 compile = server_config.get("compile", "") boost = server_config.get("boost", 1) vae_config = server_config.get("vae_config", 0) @@ -2333,7 +2345,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename, model_type, module_type = None, submodel_no = 1): +def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2366,7 +2378,16 @@ def download_models(model_filename, model_type, module_type = None, submodel_no } process_files_def(**enhancer_def) + elif server_config.get("enhancer_enabled", 0) == 2: + enhancer_def = { + "repoId" : "DeepBeepMeep/LTX_Video", + "sourceFolderList" : [ "Florence2", "llama-joycaption-beta-one-hf-llava" ], + "fileList" : [ ["config.json", "configuration_florence2.py", "model.safetensors", "modeling_florence2.py", "preprocessor_config.json", "processing_florence2.py", "tokenizer.json", "tokenizer_config.json"],["config.json", "llama_joycaption_quanto_bf16_int8.safetensors", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] ] + } + process_files_def(**enhancer_def) + download_mmaudio() + if model_filename is None: return def download_file(url,filename): if url.startswith("https://huggingface.co/") and "/resolve/main/" in url: @@ -2554,12 +2575,76 @@ def get_transformer_model(model, submodel_no = 1): else: raise Exception("no transformer found") +def init_pipe(pipe, kwargs, override_profile): + preload =int(args.preload) + if preload == 0: + preload = server_config.get("preload_in_VRAM", 0) -def load_models(model_type): - global transformer_type + kwargs["extraModelsToQuantize"]= None + profile = override_profile if 
override_profile != -1 else default_profile + if profile in (2, 4, 5): + budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } + if "transformer2" in pipe: + budgets["transformer2"] = 100 if preload == 0 else preload + kwargs["budgets"] = budgets + elif profile == 3: + kwargs["budgets"] = { "*" : "70%" } + + if "transformer2" in pipe: + if profile in [3,4]: + kwargs["pinnedMemory"] = ["transformer", "transformer2"] + + return profile + +def reset_prompt_enhancer(): + global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer, enhancer_offloadobj + prompt_enhancer_image_caption_model = None + prompt_enhancer_image_caption_processor = None + prompt_enhancer_llm_model = None + prompt_enhancer_llm_tokenizer = None + if enhancer_offloadobj is not None: + enhancer_offloadobj.release() + enhancer_offloadobj = None + +def setup_prompt_enhancer(pipe, kwargs): + global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer + model_no = server_config.get("enhancer_enabled", 0) + if model_no != 0: + from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM ) + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) + pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model + prompt_enhancer_image_caption_model._model_dtype = torch.float + # def preprocess_sd(sd, map): + # new_sd ={} + # for k, v in sd.items(): + # k = "model." 
+ k.replace(".model.", ".") + # if "lm_head.weight" in k: k = "lm_head.weight" + # new_sd[k] = v + # return new_sd, map + # prompt_enhancer_llm_model = offload.fast_load_transformers_model("c:/temp/joy/model-00001-of-00004.safetensors", modelClass= LlavaForConditionalGeneration, defaultConfigPath="ckpts/llama-joycaption-beta-one-hf-llava/config.json", preprocess_sd=preprocess_sd) + # offload.save_model(prompt_enhancer_llm_model, "joy_llava_quanto_int8.safetensors", do_quantize= True) + + if model_no == 1: + budget = 5000 + prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors") + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") + else: + budget = 10000 + prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/llama-joycaption-beta-one-hf-llava/llama_joycaption_quanto_bf16_int8.safetensors") + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/llama-joycaption-beta-one-hf-llava") + pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model + if not "budgets" in kwargs: kwargs["budgets"] = {} + kwargs["budgets"]["prompt_enhancer_llm_model"] = budget + else: + reset_prompt_enhancer() + + + +def load_models(model_type, override_profile = -1): + global transformer_type, loaded_profile base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - preload =int(args.preload) save_quantized = args.save_quantized and model_def != None model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) if "URLs2" in model_def: @@ -2583,8 +2668,6 @@ def load_models(model_type): transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype perc_reserved_mem_max = args.perc_reserved_mem_max - if preload == 0: - preload = server_config.get("preload_in_VRAM", 0) model_file_list = [model_filename] model_type_list = [model_type] module_type_list = [None] @@ -2614,44 +2697,18 @@ def load_models(model_type): model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) - kwargs = { "extraModelsToQuantize": None } + kwargs = {} + profile = init_pipe(pipe, kwargs, override_profile) + if server_config.get("enhancer_mode", 0) == 0: + setup_prompt_enhancer(pipe, kwargs) loras_transformer = ["transformer"] - if profile in (2, 4, 5): - budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) } - if "transformer2" in pipe: - budgets["transformer2"] = 100 if preload == 0 else preload - kwargs["budgets"] = budgets - elif profile == 3: - kwargs["budgets"] = { "*" : "70%" } - if "transformer2" in pipe: loras_transformer += ["transformer2"] - if profile in [3,4]: - kwargs["pinnedMemory"] = ["transformer", "transformer2"] - - global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer - if server_config.get("enhancer_enabled", 0) == 1: - from transformers import ( AutoModelForCausalLM, 
AutoProcessor, AutoTokenizer, LlamaForCausalLM ) - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors") #, configKwargs= {"_attn_implementation" :"XXXsdpa"} - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") - pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model - pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model - prompt_enhancer_image_caption_model._model_dtype = torch.float - if "budgets" in kwargs: - kwargs["budgets"]["prompt_enhancer_llm_model"] = 5000 - else: - prompt_enhancer_image_caption_model = None - prompt_enhancer_image_caption_processor = None - prompt_enhancer_llm_model = None - prompt_enhancer_llm_tokenizer = None - - offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) transformer_type = model_type + loaded_profile = profile return wan_model, offloadobj if not "P" in preload_model_policy: @@ -2722,11 +2779,12 @@ def apply_changes( state, preload_model_policy_choice = 1, UI_theme_choice = "default", enhancer_enabled_choice = 0, + enhancer_mode_choice = 0, mmaudio_enabled_choice = 0, fit_canvas_choice = 0, preload_in_VRAM_choice = 0, depth_anything_v2_variant_choice = "vitl", - notification_sound_enabled_choice = 1, + notification_sound_enabled_choice = 0, notification_sound_volume_choice = 50, max_frames_multiplier_choice = 1, display_stats_choice = 0, @@ -2736,9 +2794,9 @@ def apply_changes( state, last_resolution_choice = None, ): if args.lock_config: - return + return "
Config Locked
",*[gr.update()]*4 if gen_in_progress: - return "
Unable to change config when a generation is in progress
",*[gr.update()]*6 + return "
Unable to change config when a generation is in progress
",*[gr.update()]*4 global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = { "attention_mode" : attention_choice, @@ -2760,6 +2818,7 @@ def apply_changes( state, "UI_theme" : UI_theme_choice, "fit_canvas": fit_canvas_choice, "enhancer_enabled" : enhancer_enabled_choice, + "enhancer_mode" : enhancer_mode_choice, "mmaudio_enabled" : mmaudio_enabled_choice, "preload_in_VRAM" : preload_in_VRAM_choice, "depth_anything_v2_variant": depth_anything_v2_variant_choice, @@ -2795,9 +2854,9 @@ def apply_changes( state, if v != v_old: changes.append(k) - global attention_mode, profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path + global attention_mode, default_profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path attention_mode = server_config["attention_mode"] - profile = server_config["profile"] + default_profile = server_config["profile"] compile = server_config["compile"] text_encoder_quantization = server_config["text_encoder_quantization"] vae_config = server_config["vae_config"] @@ -2811,6 +2870,8 @@ def apply_changes( state, transformer_types = server_config["transformer_types"] model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename + if "enhancer_enabled" in changes or "enhancer_mode" in changes: + reset_prompt_enhancer() if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats", "video_output_codec", "image_output_codec", "audio_output_codec"] for change in changes ): @@ -2822,7 +2883,7 @@ def apply_changes( state, header = generate_header(state["model_type"], compile=compile, attention_mode= attention_mode) mmaudio_enabled = server_config["mmaudio_enabled"] > 0 - return "
The new configuration has been succesfully applied
", header, model_family, model_choice, gr.Row(visible= server_config["enhancer_enabled"] == 1), gr.Row(visible= mmaudio_enabled), gr.Column(visible= mmaudio_enabled) + return "
The new configuration has been successfully applied
", header, model_family, model_choice, get_unique_id() def get_gen_info(state): cache = state.get("gen", None) @@ -2835,7 +2896,25 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps): gen = get_gen_info(state) gen["num_inference_steps"] = num_inference_steps start_time = time.time() - def callback(step_idx, latent, force_refresh, read_state = False, override_num_inference_steps = -1, pass_no = -1): + def callback(step_idx = -1, latent = None, force_refresh = True, read_state = False, override_num_inference_steps = -1, pass_no = -1, denoising_extra =""): + in_pause = False + with gen_lock: + process_status = gen.get("process_status", None) + pause_msg = None + if process_status.startswith("request:"): + gen["process_status"] = "process:" + process_status[len("request:"):] + offloadobj.unload_all() + pause_msg = gen.get("pause_msg", "Unknown Pause") + in_pause = True + + if in_pause: + send_cmd("progress", [0, pause_msg]) + while True: + time.sleep(1) + with gen_lock: + process_status = gen.get("process_status", None) + if process_status == "process:main": break + refresh_id = gen.get("refresh", -1) if force_refresh or step_idx >= 0: pass @@ -2872,7 +2951,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps): phase = "Denoising Third Pass" else: phase = f"Denoising {pass_no}th Pass" - + + if len(denoising_extra) > 0: phase += " | " + denoising_extra + gen["progress_phase"] = (phase, step_idx) status_msg = merge_status_context(status, phase) @@ -2913,6 +2994,7 @@ def refresh_gallery(state): #, msg # gen["last_msg"] = msg file_list = gen.get("file_list", None) choice = gen.get("selected",0) + header_text = gen.get("header_text", "") in_progress = "in_progress" in gen if in_progress: if gen.get("last_selected", True): @@ -2924,17 +3006,14 @@ def refresh_gallery(state): #, msg return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= False) else: task = queue[0] - start_img_md = "" - end_img_md = "" prompt = task["prompt"] params = task["params"] model_type = params["model_type"] base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - is_image = model_def.get("image_outputs", False) onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0 and not params.get("mode","").startswith("edit_") enhanced = False - if prompt.startswith("!enhanced!\n"): + if prompt.startswith("!enhanced!\n"): enhanced = True prompt = prompt[len("!enhanced!\n"):] if "\n" in prompt : @@ -2947,6 +3026,9 @@ def refresh_gallery(state): #, msg prompt = "
".join(prompts) if enhanced: prompt = "Enhanced:
" + prompt + + if len(header_text) > 0: + prompt = "" + header_text + "

" + prompt list_uri = [] list_labels = [] start_img_uri = task.get('start_image_data_base64') @@ -3126,6 +3208,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_model_type = configs.get("model_type", "t2v") model_family = get_model_family(video_model_type) model_def = get_model_def(video_model_type) + multiple_submodels = model_def.get("multiple_submodels", False) video_other_prompts = ", ".join(video_other_prompts) video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})" video_length = configs.get("video_length", 0) @@ -3143,15 +3226,28 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)" video_guidance_scale = configs.get("guidance_scale", None) video_guidance2_scale = configs.get("guidance2_scale", None) + video_guidance3_scale = configs.get("guidance3_scale", None) video_switch_threshold = configs.get("switch_threshold", 0) + video_switch_threshold2 = configs.get("switch_threshold2", 0) + video_model_switch_phase = configs.get("model_switch_phase", 1) + video_guidance_phases = configs.get("guidance_phases", 0) video_embedded_guidance_scale = configs.get("embedded_guidance_scale", None) + video_guidance_label = "Guidance" if model_def.get("embedded_guidance", False): video_guidance_scale = video_embedded_guidance_scale video_guidance_label = "Embedded Guidance Scale" - else: - if video_switch_threshold > 0: - video_guidance_scale = f"{video_guidance_scale} (High Noise), {video_guidance2_scale} (Low Noise) with Switch at Noise Level {video_switch_threshold}" - video_guidance_label = "Guidance" + elif video_guidance_phases > 0: + if video_guidance_phases == 1: + video_guidance_scale = f"{video_guidance_scale}" + elif video_guidance_phases == 2: + if multiple_submodels: + video_guidance_scale = f"{video_guidance_scale} (High Noise), {video_guidance2_scale} (Low Noise) with Switch at Noise Level {video_switch_threshold}" + else: + video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale} with Guidance Switch at Noise Level {video_switch_threshold}" + else: + video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale} & {video_guidance3_scale} with Switch at Noise Levels {video_switch_threshold} & {video_switch_threshold2}" + if multiple_submodels: + video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}" video_flow_shift = configs.get("flow_shift", None) video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" @@ -3882,6 +3978,150 @@ class DynamicClass: """Alias for assign() - more dict-like""" return self.assign(**dict) + +def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, original_image_refs, is_image, seed ): + + text_encoder_max_tokens = 256 + from models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt + prompt_images = [] + if "I" in prompt_enhancer: + if image_start != None: + prompt_images.append(image_start) + if original_image_refs != None: + prompt_images += original_image_refs[:1] + if len(original_prompts) == 0 and not "T" in prompt_enhancer: + return None + else: + from shared.utils.utils import seed_everything + seed = seed_everything(seed) + # for i, original_prompt in enumerate(original_prompts): + prompts = generate_cinematic_prompt( + prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model, + 
prompt_enhancer_llm_tokenizer, + original_prompts if "T" in prompt_enhancer else ["an image"], + prompt_images if len(prompt_images) > 0 else None, + video_prompt = not is_image, + max_new_tokens=text_encoder_max_tokens, + ) + return prompts + +def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, override_profile, progress=gr.Progress()): + global enhancer_offloadobj + prefix = "#!PROMPT!:" + model_type = state["model_type"] + inputs = get_model_settings(state, model_type) + original_prompts = inputs["prompt"] + + original_prompts, errors = prompt_parser.process_template(original_prompts, keep_comments= True) + if len(errors) > 0: + gr.Info("Error processing prompt template: " + errors) + return gr.update(), gr.update() + original_prompts = original_prompts.replace("\r", "").split("\n") + + prompts_to_process = [] + skip_next_non_comment = False + for prompt in original_prompts: + if prompt.startswith(prefix): + new_prompt = prompt[len(prefix):].strip() + prompts_to_process.append(new_prompt) + skip_next_non_comment = True + else: + if not prompt.startswith("#") and not skip_next_non_comment and len(prompt) > 0: + prompts_to_process.append(prompt) + skip_next_non_comment = False + + original_prompts = prompts_to_process + num_prompts = len(original_prompts) + image_start = inputs["image_start"] + if image_start is None or not "I" in prompt_enhancer: + image_start = [None] * num_prompts + else: + image_start = [img[0] for img in image_start] + if len(image_start) == 1: + image_start = image_start * num_prompts + else: + if multi_images_gen_type !=1: + gr.Info("On Demand Prompt Enhancer with multiple Start Images requires that option 'Match images and text prompts' is set") + return gr.update(), gr.update() + + if len(image_start) != num_prompts: + gr.Info("On Demand Prompt Enhancer supports only mutiple Start Images if their number matches the number of Text Prompts") + return gr.update(), gr.update() + + if enhancer_offloadobj is None: + status = "Please Wait While Loading Prompt Enhancer" + progress(0, status) + kwargs = {} + pipe = {} + download_models() + + gen = get_gen_info(state) + while True: + with gen_lock: + process_status = gen.get("process_status", None) + if process_status is None: + original_process_status = process_status + gen["process_status"] = "process:prompt_enhancer" + break + elif process_status == "process:main": + original_process_status = process_status + gen["process_status"] = "request:prompt_enhancer" + gen["pause_msg"] = "Generation Suspended while using Prompt Enhancer" + break + elif process_status == "process:prompt_enhancer": + break + time.sleep(1) + + if original_process_status is not None: + while True: + with gen_lock: + process_status = gen.get("process_status", None) + if process_status == "process:prompt_enhancer": break + if process_status is None: + # handle case when main process has finished at some point in between the last check and now + gen["process_status"] = "process:prompt_enhancer" + break + time.sleep(0.1) + + if enhancer_offloadobj is None: + profile = init_pipe(pipe, kwargs, override_profile) + setup_prompt_enhancer(pipe, kwargs) + enhancer_offloadobj = offload.profile(pipe, profile_no= profile, **kwargs) + + original_image_refs = inputs["image_refs"] + is_image = inputs["image_mode"] == 1 + seed = inputs["seed"] + seed = set_seed(seed) + enhanced_prompts = [] + for i, (one_prompt, one_image) in enumerate(zip(original_prompts, image_start)): + status = f'Please Wait While Enhancing Prompt' if num_prompts==1 else 
f'Please Wait While Enhancing Prompt #{i+1}' + progress((i , num_prompts), desc=status, total= num_prompts) + + try: + enhanced_prompt = process_prompt_enhancer(prompt_enhancer, [one_prompt], [one_image], original_image_refs, is_image, seed ) + except Exception as e: + enhancer_offloadobj.unload_all() + with gen_lock: + gen["process_status"] = original_process_status + raise gr.Error(e) + if enhanced_prompt is not None: + enhanced_prompt = enhanced_prompt[0].replace("\n", "").replace("\r", "") + enhanced_prompts.append(prefix + " " + one_prompt) + enhanced_prompts.append(enhanced_prompt) + + enhancer_offloadobj.unload_all() + with gen_lock: + gen["process_status"] = original_process_status + + prompt = '\n'.join(enhanced_prompts) + if num_prompts > 1: + gr.Info(f'{num_prompts} Prompts have been Enhanced') + else: + gr.Info(f'Prompt "{original_prompts[0][:100]}" has been enhanced') + return prompt, prompt + def generate_video( task, send_cmd, @@ -3896,7 +4136,11 @@ def generate_video( num_inference_steps, guidance_scale, guidance2_scale, + guidance3_scale, switch_threshold, + switch_threshold2, + guidance_phases, + model_switch_phase, audio_guidance_scale, flow_shift, sample_solver, @@ -3959,6 +4203,7 @@ def generate_video( cfg_zero_step, prompt_enhancer, min_frames_if_references, + override_profile, state, model_type, model_filename, @@ -3994,7 +4239,10 @@ def generate_video( model_def = get_model_def(model_type) is_image = image_mode == 1 if is_image: - video_length = min_frames_if_references if "I" in video_prompt_type or "V" in video_prompt_type else 1 + if min_frames_if_references >= 1000: + video_length = min_frames_if_references - 1000 + else: + video_length = min_frames_if_references if "I" in video_prompt_type or "V" in video_prompt_type else 1 else: batch_size = 1 temp_filenames_list = [] @@ -4010,7 +4258,7 @@ def generate_video( image_mask = None base_model_type = get_base_model_type(model_type) - + model_family = get_model_family(base_model_type) fit_canvas = server_config.get("fit_canvas", 0) model_handler = get_model_handler(base_model_type) block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16 @@ -4019,14 +4267,14 @@ def generate_video( while wan_model == None: time.sleep(1) - if model_type != transformer_type or reload_needed: + if model_type != transformer_type or reload_needed or override_profile>0 and override_profile != loaded_profile or override_profile<0 and default_profile != loaded_profile: wan_model = None if offloadobj is not None: offloadobj.release() offloadobj = None gc.collect() send_cmd("status", f"Loading model {get_model_name(model_type)}...") - wan_model, offloadobj = load_models(model_type) + wan_model, offloadobj = load_models(model_type, override_profile) send_cmd("status", "Model loaded") reload_needed= False overridden_attention = get_overridden_attention(model_type) @@ -4060,28 +4308,28 @@ def generate_video( prompts = prompt.split("\n") prompts = [part for part in prompts if len(prompt)>0] parsed_keep_frames_video_source= max_source_video_frames if len(keep_frames_video_source) ==0 else int(keep_frames_video_source) - transformer_loras_filenames, transformer_loras_multipliers = get_transformer_loras(model_type) + if guidance_phases < 1: guidance_phases = 1 if transformer_loras_filenames != None: - loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(transformer_loras_multipliers, len(transformer_loras_filenames), num_inference_steps) + 
loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(transformer_loras_multipliers, len(transformer_loras_filenames), num_inference_steps, nb_phases = guidance_phases ) if len(errors) > 0: raise Exception(f"Error parsing Transformer Loras: {errors}") loras_selected = transformer_loras_filenames if hasattr(wan_model, "get_loras_transformer"): extra_loras_transformers, extra_loras_multipliers = wan_model.get_loras_transformer(get_model_recursive_prop, **locals()) - loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(extra_loras_multipliers, len(extra_loras_transformers), num_inference_steps, merge_slist= loras_slists ) + loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(extra_loras_multipliers, len(extra_loras_transformers), num_inference_steps, nb_phases = guidance_phases, merge_slist= loras_slists ) if len(errors) > 0: raise Exception(f"Error parsing Extra Transformer Loras: {errors}") loras_selected += extra_loras_transformers loras = state["loras"] if len(loras) > 0: - loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, merge_slist= loras_slists ) + loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases = guidance_phases, merge_slist= loras_slists ) if len(errors) > 0: raise Exception(f"Error parsing Loras: {errors}") lora_dir = get_lora_dir(model_type) loras_selected += [ os.path.join(lora_dir, lora) for lora in activated_loras] if len(loras_selected) > 0: - pinnedLora = profile !=5 # and transformer_loras_filenames == None False # # # + pinnedLora = loaded_profile !=5 # and transformer_loras_filenames == None False # # # split_linear_modules_map = getattr(trans,"split_linear_modules_map", None) offload.load_loras_into_model(trans , loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map) errors = trans._loras_errors @@ -4115,7 +4363,8 @@ def generate_video( hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - flux = base_model_type in ["flux"] + flux = model_family in ["flux"] + qwen = model_family in ["qwen"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4174,7 +4423,7 @@ def generate_video( send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from shared.utils.utils import resize_and_remove_background - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux) ) # no fit for vace ref images as it is done later + image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux or qwen) ) # no fit for vace ref images as it is done later update_task_thumbnails(task, locals()) send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 @@ -4276,7 +4525,8 @@ def 
generate_video( extra_generation += gen.get("extra_orders",0) gen["extra_orders"] = 0 total_generation = repeat_generation + extra_generation - gen["total_generation"] = total_generation + gen["total_generation"] = total_generation + gen["header_text"] = "" if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no @@ -4301,36 +4551,14 @@ def generate_video( num_frames_generated = 0 # num of new frames created (lower than the number of frames really processed due to overlaps and discards) requested_frames_to_generate = default_requested_frames_to_generate # num of num frames to create (if any source window this num includes also the overlapped source window frames) start_time = time.time() - if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0: - text_encoder_max_tokens = 256 + if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0 and server_config.get("enhancer_mode", 0) == 0: send_cmd("progress", [0, get_latest_status(state, "Enhancing Prompt")]) - from models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt - prompt_images = [] - if "I" in prompt_enhancer: - if image_start != None: - prompt_images.append(image_start) - if original_image_refs != None: - prompt_images += original_image_refs[:1] - if len(original_prompts) == 0 and not "T" in prompt_enhancer: - pass - else: - from shared.utils.utils import seed_everything - seed_everything(seed) - # for i, original_prompt in enumerate(original_prompts): - prompts = generate_cinematic_prompt( - prompt_enhancer_image_caption_model, - prompt_enhancer_image_caption_processor, - prompt_enhancer_llm_model, - prompt_enhancer_llm_tokenizer, - original_prompts if "T" in prompt_enhancer else ["an image"], - prompt_images if len(prompt_images) > 0 else None, - video_prompt = not is_image, - max_new_tokens=text_encoder_max_tokens, - ) - print(f"Enhanced prompts: {prompts}" ) - task["prompt"] = "\n".join(["!enhanced!"] + prompts) + enhanced_prompts = process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, original_image_refs, is_image, seed ) + if enhanced_prompts is not None: + print(f"Enhanced prompts: {enhanced_prompts}" ) + task["prompt"] = "\n".join(["!enhanced!"] + enhanced_prompts) send_cmd("output") - prompt = prompts[0] + prompt = enhanced_prompts[0] abort = gen.get("abort", False) while not abort: @@ -4355,7 +4583,6 @@ def generate_video( window_no += 1 gen["window_no"] = window_no return_latent_slice = None - if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : None, "image_mask" : None} @@ -4529,7 +4756,10 @@ def generate_video( }) # samples = torch.empty( (1,2)) #for testing # if False: - + def set_header_text(txt): + gen["header_text"] = txt + send_cmd("output") + try: samples = wan_model.generate( input_prompt = prompt, @@ -4551,7 +4781,11 @@ def generate_video( sampling_steps=num_inference_steps, guide_scale=guidance_scale, guide2_scale = guidance2_scale, + guide3_scale = guidance3_scale, switch_threshold = switch_threshold, + switch2_threshold = switch_threshold2, + guide_phases= guidance_phases, + model_switch_phase = model_switch_phase, embedded_guidance_scale=embedded_guidance_scale, n_prompt=negative_prompt, seed=seed, @@ -4593,6 +4827,7 @@ def generate_video( video_prompt_type= video_prompt_type, window_no = 
window_no, offloadobj = offloadobj, + set_header_text= set_header_text, ) except Exception as e: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -4811,7 +5046,7 @@ def generate_video( # Play notification sound for single video try: - if server_config.get("notification_sound_enabled", 1): + if server_config.get("notification_sound_enabled", 0): volume = server_config.get("notification_sound_volume", 50) notification_sound.notify_video_completion( video_path=video_path, @@ -4912,6 +5147,22 @@ def process_tasks(state): gen["file_list"] = file_list gen["file_settings_list"] = file_settings_list + while True: + with gen_lock: + process_status = gen.get("process_status", None) + if process_status is None: + gen["process_status"] = "process:main" + break + time.sleep(1) + + def release_gen(): + with gen_lock: + process_status = gen.get("process_status", None) + if process_status.startswith("request:"): + gen["process_status"] = "process:" + process_status[len("request:"):] + else: + gen["process_status"] = None + start_time = time.time() global gen_in_progress @@ -4919,6 +5170,8 @@ def process_tasks(state): gen["in_progress"] = True gen["preview"] = None gen["status"] = "Generating Video" + gen["header_text"] = "" + yield time.time(), time.time() prompt_no = 0 while len(queue) > 0: @@ -4954,7 +5207,7 @@ def process_tasks(state): gen["prompts_max"] = 0 gen["prompt"] = "" gen["status_display"] = False - + release_gen() raise gr.Error(data, print_exception= False, duration = 0) elif cmd == "status": gen["status"] = data @@ -4970,6 +5223,7 @@ def process_tasks(state): gen["preview"] = preview yield time.time() , gr.Text() else: + release_gen() raise Exception(f"unknown command {cmd}") abort = gen.get("abort", False) @@ -5001,6 +5255,7 @@ def process_tasks(state): print(f"Error playing notification sound: {e}") gen["status"] = status gen["status_display"] = False + release_gen() @@ -5498,16 +5753,17 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None inputs["settings_version"] = settings_version model_def = get_model_def(model_type) base_model_type = get_base_model_type(model_type) + model_family = get_model_family(base_model_type) if model_type != base_model_type: inputs["base_model_type"] = base_model_type diffusion_forcing = base_model_type in ["sky_df_1.3B", "sky_df_14B"] vace = test_vace_module(base_model_type) + t2v= base_model_type in ["t2v"] ltxv = base_model_type in ["ltxv_13B"] recammaster = base_model_type in ["recam_1.3B"] phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] - flux = base_model_type in ["flux"] + flux = model_family in ["flux"] hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"] - model_family = get_model_family(base_model_type) if target == "settings": return inputs @@ -5528,7 +5784,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not base_model_type in ["t2v"]: pop += ["denoising_strength"] - if not server_config.get("enhancer_enabled", 0) == 1: + if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0): pop += ["prompt_enhancer"] if not recammaster and not diffusion_forcing and not flux: @@ -5539,7 +5795,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if base_model_type in ["t2v"]: unsaved_params = unsaved_params[1:] pop += unsaved_params if not vace: - pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", 
"control_net_weight2", "min_frames_if_references"] + pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"] + + if not (vace or t2v): + pop += ["min_frames_if_references"] if not (diffusion_forcing or ltxv or vace): pop += ["keep_frames_video_source"] @@ -5550,20 +5809,24 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not (base_model_type in ["fantasy"] or model_def.get("multitalk_class", False)): pop += ["audio_guidance_scale", "speakers_locations"] - if not model_def.get("embedded_guidance", False) or model_def.get("no_guidance", False): + if not model_def.get("embedded_guidance", False): pop += ["embedded_guidance_scale"] if not (model_def.get("tea_cache", False) or model_def.get("mag_cache", False)) : pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"] - if model_def.get("no_guidance", False) : - pop += ["guidance_scale", "guidance2_scale", "switch_threshold", "audio_guidance_scale"] + guidance_max_phases = model_def.get("guidance_max_phases", 0) + guidance_phases = inputs.get("guidance_phases", 1) + if guidance_max_phases < 1: + pop += ["guidance_scale", "guidance_phases"] - - if not model_def.get("guidance_max_phases",1) >1: + if guidance_max_phases < 2 or guidance_phases < 2: pop += ["guidance2_scale", "switch_threshold"] - if model_def.get("image_outputs", False) or ltxv: + if guidance_max_phases < 3 or guidance_phases < 3: + pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"] + + if ltxv: pop += ["flow_shift"] if model_def.get("no_negative_prompt", False) : @@ -5920,7 +6183,11 @@ def save_inputs( num_inference_steps, guidance_scale, guidance2_scale, - switch_threshold, + guidance3_scale, + switch_threshold, + switch_threshold2, + guidance_phases, + model_switch_phase, audio_guidance_scale, flow_shift, sample_solver, @@ -5982,7 +6249,8 @@ def save_inputs( cfg_star_switch, cfg_zero_step, prompt_enhancer, - min_frames_if_references, + min_frames_if_references, + override_profile, mode, state, ): @@ -6415,7 +6683,30 @@ def record_last_resolution(state, resolution): def get_max_frames(nb): return (nb - 1) * server_config.get("max_frames_multiplier",1) + 1 -def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None): + +def change_guidance_phases(state, guidance_phases): + model_type = state["model_type"] + model_def = get_model_def(model_type) + multiple_submodels = model_def.get("multiple_submodels", False) + label ="Phase 1-2" if guidance_phases ==3 else ( "Model / Guidance Switch Threshold" if multiple_submodels else "Guidance Switch Threshold" ) + return gr.update(visible= guidance_phases >=3 and multiple_submodels) , gr.update(visible= guidance_phases >=2), gr.update(visible= guidance_phases >=2, label = label), gr.update(visible= guidance_phases >=3), gr.update(visible= guidance_phases >=2), gr.update(visible= guidance_phases >=3) + + +memory_profile_choices= [ ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos with a RTX 3090 / RTX 4090", 1), + ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2), + ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited 
RAM for good speed short video",3), + ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4), + ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)] + +def detect_auto_save_form(state, evt:gr.SelectData): + last_tab_id = state.get("last_tab_id", 0) + state["last_tab_id"] = new_tab_id = evt.index + if new_tab_id > 0 and last_tab_id == 0: + return get_unique_id() + else: + return gr.update() + +def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced if update_form: @@ -6522,6 +6813,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non t2v = base_model_type in ["t2v"] t2v_1_3B = base_model_type in ["t2v_1.3B"] flf2v = base_model_type == "flf2v_720p" + base_model_family = get_model_family(base_model_type) diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) @@ -6540,7 +6832,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non hunyuan_video_custom_audio = base_model_type in ["hunyuan_custom_audio"] hunyuan_video_custom_edit = base_model_type in ["hunyuan_custom_edit"] hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename - flux = base_model_type in ["flux"] + flux = base_model_family in ["flux"] + qwen = base_model_family in ["qwen"] image_outputs = model_def.get("image_outputs", False) sliding_window_enabled = test_any_sliding_window(model_type) multi_prompts_gen_type_value = ui_defaults.get("multi_prompts_gen_type_value",0) @@ -6673,7 +6966,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image) as video_prompt_column: + with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) any_control_video = True @@ -6799,7 +7092,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) - elif flux and model_reference_image: + elif (flux or qwen) and model_reference_image: video_prompt_type_image_refs = gr.Dropdown( choices=[ ("None", ""), @@ -6841,7 +7134,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - 
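The `change_guidance_phases` helper above only toggles widget visibility; the actual switching happens inside the sampler, driven by the new `guide_phases`, `switch_threshold`, `switch2_threshold`, `guide3_scale` and `model_switch_phase` arguments passed to `wan_model.generate(...)` earlier in this diff. Below is a minimal sketch of that selection logic, assuming the thresholds are compared against the denoising timestep on the same 0–1000 scale used by the `switch_threshold` sliders; the function names and the example numbers are illustrative only, not WanGP's actual implementation.

```python
# Illustrative sketch of per-step guidance selection for up to three phases.
# Assumption: timesteps run from 1000 down to 0 and a phase ends once the
# current timestep drops below its switch threshold.

def select_phase(t: float, guide_phases: int,
                 switch_threshold: float, switch_threshold2: float) -> int:
    """Return the 1-based phase index active at timestep t."""
    if guide_phases >= 2 and t < switch_threshold:
        if guide_phases >= 3 and t < switch_threshold2:
            return 3
        return 2
    return 1

def guidance_for_step(t, guide_phases, switch_threshold, switch_threshold2,
                      guide_scale, guide2_scale, guide3_scale,
                      model_switch_phase=1):
    phase = select_phase(t, guide_phases, switch_threshold, switch_threshold2)
    scale = {1: guide_scale, 2: guide2_scale, 3: guide3_scale}[phase]
    # With two submodels (high-noise / low-noise experts), model_switch_phase
    # selects the transition (1 = phase 1-2, 2 = phase 2-3) after which the
    # second submodel takes over.
    use_second_model = phase > model_switch_phase
    return phase, scale, use_second_model

# Example: CFG only in phase 1, then two guidance-free accelerated phases,
# with the submodel switch at the phase 2-3 transition (arbitrary thresholds).
for t in (990, 800, 700):
    print(t, guidance_for_step(t, 3, 900, 750, 3.5, 1.0, 1.0, model_switch_phase=2))
```

Seen this way, the `prepare_inputs_dict` change above is just bookkeeping: `guidance3_scale`, `switch_threshold2` and `model_switch_phase` are dropped from the saved settings whenever fewer than three phases are in play, because those values are never consulted.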
any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar + any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or (flux or qwen) and model_reference_image image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images", type ="pil", show_label= True, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, @@ -6852,7 +7145,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non remove_background_images_ref = gr.Dropdown( choices=[ ("Keep Backgrounds behind all Reference Images", 0), - ("Remove Backgrounds only behind People / Objects except main Subject" if flux else "Remove Backgrounds only behind People / Objects" , 1), + ("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else "Remove Backgrounds only behind People / Objects" , 1), ], value=ui_defaults.get("remove_background_images_ref",1), label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar @@ -6911,17 +7204,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non wizard_prompt = gr.Textbox(visible = not advanced_prompt, label=wizard_prompt_label, value=default_wizard_prompt, lines=3) wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) - with gr.Row(visible= server_config.get("enhancer_enabled", 0) == 1 ) as prompt_enhancer_row: + with gr.Row(visible= server_config.get("enhancer_enabled", 0) > 0 ) as prompt_enhancer_row: + on_demand_prompt_enhancer = server_config.get("enhancer_mode", 0) == 1 + prompt_enhancer_value = ui_defaults.get("prompt_enhancer", "") + if len(prompt_enhancer_value) == 0 and on_demand_prompt_enhancer: prompt_enhancer_value = "T" + prompt_enhancer_btn = gr.Button( value ="Enhance Prompt", visible= on_demand_prompt_enhancer, size="lg", elem_classes="btn_centered") prompt_enhancer = gr.Dropdown( - choices=[ - ("Disabled", ""), - ("Based on Text Prompts", "T"), - ("Based on Image Prompts (such as Start Image and Reference Images)", "I"), - ("Based on both Text Prompts and Image Prompts", "TI"), + choices= + ([] if on_demand_prompt_enhancer else [("Disabled", "")]) + + [("Based on Text Prompt Content", "T"), + ("Based on Images Prompts Content (such as Start Image and Reference Images)", "I"), + ("Based on both Text Prompt and Images Prompts Content", "TI"), ], - value=ui_defaults.get("prompt_enhancer", ""), - label="Enhance Prompt using a LLM", scale = 3, - visible= True + value=prompt_enhancer_value, + label="Enhance Prompt using a LLM", scale = 5, + visible= True, show_label= not on_demand_prompt_enhancer, ) with gr.Row(): if server_config.get("fit_canvas", 0) == 1: @@ -6963,31 +7260,59 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui) with gr.Tabs(visible=advanced_ui) as advanced_row: - # with gr.Row(visible=advanced_ui) as advanced_row: - no_guidance = model_def.get("no_guidance", False) + guidance_max_phases = model_def.get("guidance_max_phases", 0) no_negative_prompt = model_def.get("no_negative_prompt", False) + any_audio_guidance = fantasy or multitalk with gr.Tab("General"): with gr.Column(): - seed = gr.Slider(-1, 999999999, 
value=ui_defaults.get("seed",-1), step=1, label="Seed (-1 for random)") - any_embedded_guidance = model_def.get("embedded_guidance", False) - with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row: - guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) - audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=(fantasy or multitalk) and not no_guidance) - embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) - flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) - with gr.Row(visible = model_def.get("guidance_max_phases",1) >1 and not (no_guidance and image_outputs)) as guidance_row2: - guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) - switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or any_embedded_guidance) and not no_guidance) + with gr.Row(): + seed = gr.Slider(-1, 999999999, value=ui_defaults.get("seed",-1), step=1, label="Seed (-1 for random)", scale=2) + guidance_phases_value = ui_defaults.get("guidance_phases", 1) + guidance_phases = gr.Dropdown( + choices=[ + ("One Phase", 1), + ("Two Phases", 2), + ("Three Phases", 3)], + value= guidance_phases_value, + label="Guidance Phases", + visible= guidance_max_phases >=2, + interactive = not model_def.get("lock_guidance_phases", False) + ) + with gr.Row(visible = guidance_phases_value >=2 ) as guidance_phases_row: + multiple_submodels = model_def.get("multiple_submodels", False) + model_switch_phase = gr.Dropdown( + choices=[ + ("Phase 1-2 transition", 1), + ("Phase 2-3 transition", 2)], + value=ui_defaults.get("model_switch_phase", 1), + label="Model Switch", + visible= model_def.get("multiple_submodels", False) and guidance_phases_value >= 3 and multiple_submodels + ) + label ="Phase 1-2" if guidance_phases_value ==3 else ( "Model / Guidance Switch Threshold" if multiple_submodels else "Guidance Switch Threshold" ) + switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label = label, visible= guidance_max_phases >= 2 and guidance_phases_value >= 2) + switch_threshold2 = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold2",0), step=1, label="Phase 2-3", visible= guidance_max_phases >= 3 and guidance_phases_value >= 3) + with gr.Row(visible = guidance_max_phases >=1 ) as guidance_row: + guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=guidance_max_phases >=1 ) + guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible= guidance_max_phases >=2 and guidance_phases_value >= 2) + guidance3_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance3_scale",5), step=0.5, label="Guidance3 (CFG)", visible= guidance_max_phases >=3 and 
guidance_phases_value >= 3) + + sample_solver_choices = model_def.get("sample_solvers", None) - with gr.Row(visible = sample_solver_choices is not None ) as sample_solver_row: + with gr.Row(visible = sample_solver_choices is not None or not image_outputs) as sample_solver_row: if sample_solver_choices is None: sample_solver = gr.Dropdown( value="", choices=[ ("", ""), ], visible= False, label= "Sampler Solver / Scheduler" ) else: sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver", sample_solver_choices[0][1]), choices= sample_solver_choices, visible= True, label= "Sampler Solver / Scheduler" ) + flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) + any_embedded_guidance = model_def.get("embedded_guidance", False) + with gr.Row(visible =any_embedded_guidance or any_audio_guidance) as embedded_guidance_row: + audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 4), step=0.5, label="Audio Guidance", visible= any_audio_guidance ) + embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 6.0), step=0.5, label="Embedded Guidance Scale", visible=any_embedded_guidance ) + with gr.Row(visible = vace) as control_net_weights_row: control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) @@ -7177,20 +7502,24 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(): cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero) - with gr.Column(visible = vace and image_outputs) as min_frames_if_references_col: - gr.Markdown("If using Reference Images, generating a single Frame alone may not be sufficient to preserve Identity") + with gr.Column(visible = (vace or t2v) and image_outputs) as min_frames_if_references_col: + gr.Markdown("Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. 
A workaround is to generate a short Video and keep the First Frame.") min_frames_if_references = gr.Dropdown( choices=[ ("Disabled, generate only one Frame", 1), - ("Generate a 5 Frames long Video but keep only the First Frame (x1.5 slower)",5), - ("Generate a 9 Frames long Video but keep only the First Frame (x2.0 slower)",9), - ("Generate a 13 Frames long Video but keep only the First Frame (x2.5 slower)",13), - ("Generate a 17 Frames long Video but keep only the First Frame (x3.0 slower)",17), + ("Generate a 5 Frames long Video only if any Reference Image / Control Image (x1.5 slower)",5), + ("Generate a 9 Frames long Video only if any Reference Image / Control Image (x2.0 slower)",9), + ("Generate a 13 Frames long Video only if any Reference Image / Control Image (x2.5 slower)",13), + ("Generate a 17 Frames long Video only if any Reference Image / Control Image (x3.0 slower)",17), + ("Generate always a 5 Frames long Video (x1.5 slower)",1005), + ("Generate always a 9 Frames long Video (x2.0 slower)",1009), + ("Generate always a 13 Frames long Video (x2.5 slower)",1013), + ("Generate always a 17 Frames long Video (x3.0 slower)",1017), ], - value=ui_defaults.get("min_frames_if_references",5), + value=ui_defaults.get("min_frames_if_references",5 if vace else 1), visible=True, scale = 1, - label="Generate more frames to preserve Reference Image Identity or Control Image Information" + label="Generate more frames to preserve Reference Image Identity / Control Image Information or improve" ) with gr.Tab("Sliding Window", visible= sliding_window_enabled and not image_outputs) as sliding_window_tab: @@ -7244,7 +7573,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Text Prompts separated by a Carriage Return" ) - with gr.Tab("Misc.", visible = not image_outputs) as misc_tab: + with gr.Tab("Misc.", visible = True) as misc_tab: with gr.Column(visible = not (recammaster or ltxv or diffusion_forcing)) as RIFLEx_setting_col: gr.Markdown("With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model") RIFLEx_setting = gr.Dropdown( @@ -7271,6 +7600,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if any_video_source or recammaster: force_fps_choices += [("Source Video fps", "source")] force_fps_choices += [ + ("15", "15"), ("16", "16"), ("23", "23"), ("24", "24"), @@ -7285,6 +7615,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) + gr.Markdown("You can set a more agressive Memory Profile if you generate only Short Videos or Images") + override_profile = gr.Dropdown( + choices=[("Default Memory Profile", -1)] + memory_profile_choices, + value=ui_defaults.get("override_profile", -1), + label=f"Override Memory Profile" + ) with gr.Row(): save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config) @@ -7305,6 +7641,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non output_trigger = gr.Text(interactive= False, visible=False) refresh_form_trigger = gr.Text(interactive= False, visible=False) fill_wizard_prompt_trigger = gr.Text(interactive= False, visible=False) + saveform_trigger = gr.Text(interactive= False, visible=False) with gr.Accordion("Video Info and Late Post Processing & Audio Remuxing", open=False) as video_info_accordion: with gr.Tabs() as video_info_tabs: @@ -7407,8 +7744,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non 
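The reworked `min_frames_if_references` dropdown above packs two behaviours into a single integer: the plain values (5/9/13/17) only apply when a Reference or Control Image is present, while the 1005/1009/1013/1017 entries force a short video unconditionally. A hedged sketch of how such a value could be decoded at generation time; the function and parameter names are hypothetical and the decoding is inferred from the option labels, not copied from WanGP.

```python
def resolve_num_frames(min_frames_if_references: int,
                       has_reference_or_control_image: bool) -> int:
    """Sketch: number of frames to generate for an image-output model
    (only the first frame is kept afterwards)."""
    if min_frames_if_references > 1000:
        # "Generate always a N Frames long Video" options are stored as 1000 + N
        return min_frames_if_references - 1000
    if min_frames_if_references > 1 and has_reference_or_control_image:
        # "... only if any Reference Image / Control Image" options
        return min_frames_if_references
    return 1  # "Disabled, generate only one Frame"

assert resolve_num_frames(1009, False) == 9   # always
assert resolve_num_frames(9, False) == 1      # conditional, no reference image
assert resolve_num_frames(9, True) == 9       # conditional, reference image given
```

The 5/9/13/17 ladder presumably mirrors the 4n+1 frame counts the video VAE works in, which would explain why a "single image" is extended in exactly those steps.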
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, - min_frames_if_references_col, video_prompt_type_alignment] # presets_column, + NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, + min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7421,7 +7758,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution]) resolution.change(fn=record_last_resolution, inputs=[state, resolution]) - + + guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) @@ -7494,6 +7832,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs= None ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_family, model_choice, refresh_form_trigger]) + + prompt_enhancer_btn.click(fn=validate_wizard_prompt, + inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt] + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ).then( fn=enhance_prompt, inputs =[state, prompt, prompt_enhancer, multi_images_gen_type, override_profile ] , outputs= [prompt, wizard_prompt]) + + saveform_trigger.change(fn=validate_wizard_prompt, + inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt] + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ) + + main_tabs.select(fn=detect_auto_save_form, inputs= [state], outputs= saveform_trigger, trigger_mode="multiple") + + video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] ) gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_video2_btn.click, video_info_eject_video3_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, 
last_choice], outputs = [output, video_info, video_buttons_row] ) video_info_to_control_video_btn.click(fn=video_to_control_video, inputs =[state, output, last_choice], outputs = [video_guide] ) @@ -7738,7 +8096,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) return ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, audio_tab, PP_MMAudio_col + video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger ) @@ -7756,7 +8114,7 @@ def generate_download_tab(lset_name,loras_choices, state): download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) -def generate_configuration_tab(state, blocks, header, model_family, model_choice, resolution, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col): +def generate_configuration_tab(state, blocks, header, model_family, model_choice, resolution, refresh_form_trigger): gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.") with gr.Column(): with gr.Tabs(): @@ -7956,14 +8314,8 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice ) profile_choice = gr.Dropdown( - choices=[ - ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1), - ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2), - ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3), - ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4), - ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5) - ], - value= profile, + choices = memory_profile_choices, + value= default_profile, label="Profile (for power users only, not needed to change it)" ) preload_in_VRAM_choice = gr.Slider(0, 40000, value=server_config.get("preload_in_VRAM", 0), step=100, label="Number of MB of Models that are Preloaded in VRAM (0 will use Profile default)") @@ -7971,10 +8323,20 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice enhancer_enabled_choice = gr.Dropdown( choices=[ ("Off", 0), - ("On", 1), + ("Florence 2 1.6B + LLava 3.2 3.5B", 1), + ("Florence 2 1.6B + LLama Joy Caption (uncensored but needs more VRAM) 9,3B", 2), ], value=server_config.get("enhancer_enabled", 0), - label="Prompt Enhancer (if enabled, 8 GB of extra models will be downloaded)" + label="Prompt Enhancer (if enabled, from 8 GB to 14 GB of extra models will be downloaded)" + ) + + enhancer_mode_choice = gr.Dropdown( + choices=[ + ("Automatically triggered when Generating a Video", 0), + ("On Demand Only", 1), + ], + value=server_config.get("enhancer_mode", 0), + label="Prompt Enhancer Usage" ) mmaudio_enabled_choice = gr.Dropdown( @@ -8041,7 +8403,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice ("On", 1), 
("Off", 0), ], - value=server_config.get("notification_sound_enabled", 1), + value=server_config.get("notification_sound_enabled", 0), label="Notification Sound Enabled" ) @@ -8079,6 +8441,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice preload_model_policy_choice, UI_theme_choice, enhancer_enabled_choice, + enhancer_mode_choice, mmaudio_enabled_choice, fit_canvas_choice, preload_in_VRAM_choice, @@ -8092,19 +8455,20 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice audio_output_codec_choice, resolution, ], - outputs= [msg , header, model_family, model_choice, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col] + outputs= [msg , header, model_family, model_choice, refresh_form_trigger] ) def generate_about_tab(): - gr.Markdown("

WanGP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)
") - gr.Markdown("Original Wan 2.1 Model by Alibaba (GitHub)") + gr.Markdown("

WanGP - AI Generative Models for the GPU Poor by DeepBeepMeep (GitHub)
") gr.Markdown("Many thanks to:") - gr.Markdown("- Alibaba Wan team for the best open source video generator") - gr.Markdown("- Alibaba Vace, Multitalk and Fun Teams for their incredible control net models") - gr.Markdown("- Tencent for the impressive Hunyuan Video models") - gr.Markdown("- Blackforest Labs for the innovative Flux image generators") - gr.Markdown("- Lightricks for their super fast LTX Video models") - gr.Markdown("
Huge acknowlegments to these great open source projects used in WanGP:") + gr.Markdown("- Alibaba Wan Team for the best open source video generators (https://github.com/Wan-Video/Wan2.1)") + gr.Markdown("- Alibaba Vace, Multitalk and Fun Teams for their incredible control net models (https://github.com/ali-vilab/VACE), (https://github.com/MeiGen-AI/MultiTalk) and (https://huggingface.co/alibaba-pai/Wan2.2-Fun-A14B-InP) ") + gr.Markdown("- Tencent for the impressive Hunyuan Video models (https://github.com/Tencent-Hunyuan/HunyuanVideo)") + gr.Markdown("- Blackforest Labs for the innovative Flux image generators (https://github.com/black-forest-labs/flux)") + gr.Markdown("- Alibaba Qwen Team for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)") + gr.Markdown("- Lightricks for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)") + gr.Markdown("- Hugging Face for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)") + gr.Markdown("
Huge acknowledgments to these great open source projects used in WanGP:") gr.Markdown("- Rife: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)") gr.Markdown("- DwPose: Open Pose extractor (https://github.com/IDEA-Research/DWPose)") gr.Markdown("- DepthAnything & Midas: Depth extractors (https://github.com/DepthAnything/Depth-Anything-V2) and (https://github.com/isl-org/MiDaS") @@ -8117,7 +8481,7 @@ def generate_about_tab(): gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance") gr.Markdown("- Remade_AI : for their awesome Loras collection") gr.Markdown("- Reevoy24 : for his repackaging / completing the documentation") - gr.Markdown("- Redtash1 : for designing the protype of the RAM /VRAM stats viewer") + gr.Markdown("- Redtash1 : for designing the protype of the RAM / VRAM stats viewer") def generate_info_tab(): @@ -8131,6 +8495,9 @@ def generate_info_tab(): with open("docs/LORAS.md", "r", encoding="utf-8") as reader: loras = reader.read() + with open("docs/FINETUNES.md", "r", encoding="utf-8") as reader: + finetunes = reader.read() + with gr.Tabs() : with gr.Tab("Models", id="models"): gr.Markdown(models) @@ -8138,6 +8505,8 @@ def generate_info_tab(): gr.Markdown(loras) with gr.Tab("Vace", id="vace"): gr.Markdown(vace) + with gr.Tab("Finetunes", id="finetunes"): + gr.Markdown(finetunes) def compact_name(family_name, model_name): if model_name.startswith(family_name): @@ -8628,6 +8997,7 @@ def create_ui(): opacity: 1; transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* 1s delay before showing */ } + .btn_centered {margin-top:10px; text-wrap-mode: nowrap;} """ UI_theme = server_config.get("UI_theme", "default") UI_theme = args.theme if len(args.theme) > 0 else UI_theme @@ -8683,8 +9053,8 @@ def create_ui(): with gr.Row(): ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col - ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main) + video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger + ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: @@ -8693,7 +9063,7 @@ def create_ui(): with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) with gr.Tab("Configuration", id="configuration") as configuration_tab: - generate_configuration_tab(state, main, header, model_family, model_choice, resolution, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col) + generate_configuration_tab(state, main, header, model_family, model_choice, resolution, refresh_form_trigger) with gr.Tab("About"): generate_about_tab() if stats_app is not None:
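Both `saveform_trigger` and `refresh_form_trigger` rely on the same Gradio idiom: a hidden `gr.Text` whose `.change` event carries a chain of callbacks, fired by writing a fresh value into it (here from `main_tabs.select` via `detect_auto_save_form`, which returns `get_unique_id()` only when the user leaves the first tab). Below is a self-contained toy version of the pattern, independent of WanGP's state dictionary; `detect_leave_main_tab` and `save_form` are made-up names for illustration.

```python
# Toy illustration of the hidden-trigger idiom (not WanGP code).
import time
import gradio as gr

def detect_leave_main_tab(last_tab_id, evt: gr.SelectData):
    new_tab_id = evt.index
    if new_tab_id > 0 and last_tab_id == 0:
        # A new value fires saveform_trigger.change and runs the save chain.
        return new_tab_id, str(time.time())
    # gr.update() leaves the trigger untouched, so nothing runs.
    return new_tab_id, gr.update()

def save_form(prompt_value):
    print("auto-saving form, prompt =", prompt_value)

with gr.Blocks() as demo:
    last_tab_id = gr.State(0)
    saveform_trigger = gr.Text(visible=False)
    with gr.Tabs() as main_tabs:
        with gr.Tab("Video Generator"):
            prompt = gr.Textbox(label="Prompt")
        with gr.Tab("Configuration"):
            gr.Markdown("Settings go here")
    main_tabs.select(detect_leave_main_tab, inputs=[last_tab_id],
                     outputs=[last_tab_id, saveform_trigger],
                     trigger_mode="multiple")
    saveform_trigger.change(save_form, inputs=[prompt], outputs=None)

# demo.launch()
```

The same trick is presumably why `generate_configuration_tab` now receives only `refresh_form_trigger`: applying the settings can simply bump that one hidden value and let the existing `refresh_form_trigger` chain rebuild the generator form.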