diff --git a/README.md b/README.md index d3da22e..85c68ac 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates +### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me +Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. + +So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. + +Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. + +The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence. + ### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : **Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). diff --git a/hyvideo/modules/models.py b/hyvideo/modules/models.py index a29e4bd..de50efc 100644 --- a/hyvideo/modules/models.py +++ b/hyvideo/modules/models.py @@ -393,7 +393,7 @@ class MMSingleStreamBlock(nn.Module): ##### More spagheti VRAM optimizations done by DeepBeepMeep ! # I am sure you are a nice person and as you copy this code, you will give me proper credits: - # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter + # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter if condition_type == "token_replace": mod, tr_mod = self.modulation(vec, diff --git a/wan/any2video.py b/wan/any2video.py index e9a7026..a2901d4 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -392,6 +392,9 @@ class WanAny2V: keep_frames_parsed = [], model_type = None, loras_slists = None, + NAG_scale = 0, + NAG_tau = 3.5, + NAG_alpha = 0.5, offloadobj = None, apg_switch = False, **bbargs @@ -443,11 +446,21 @@ class WanAny2V: context_null = self.text_encoder([n_prompt], self.device)[0] context = context.to(self.dtype) context_null = context_null.to(self.dtype) + text_len = self.model.text_len + context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) + context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + # NAG_prompt = "static, low resolution, blurry" + # context_NAG = self.text_encoder([NAG_prompt], self.device)[0] + # context_NAG = context_NAG.to(self.dtype) + # context_NAG = torch.cat([context_NAG, context_NAG.new_zeros(text_len -context_NAG.size(0), context_NAG.size(1)) ]).unsqueeze(0) + # from mmgp import offload # offloadobj.unload_all() - if self._interrupt: - return None + offload.shared_state.update({"_nag_scale" : NAG_scale, "_nag_tau" : NAG_tau, "_nag_alpha": NAG_alpha }) + if NAG_scale > 1: context = torch.cat([context, context_null], dim=0) + # if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0) + if self._interrupt: return None vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index c9b735c..ee168ec 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -45,7 +45,7 @@ class DTT2V: self.dtype = dtype self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype - + self.text_len = config.text_len self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, @@ -250,11 +250,15 @@ class DTT2V: if self._interrupt: return None + text_len = self.text_len prompt_embeds = self.text_encoder([input_prompt], self.device)[0] prompt_embeds = prompt_embeds.to(self.dtype).to(self.device) + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds.new_zeros(text_len -prompt_embeds.size(0), prompt_embeds.size(1)) ]).unsqueeze(0) + if self.do_classifier_free_guidance: negative_prompt_embeds = self.text_encoder([n_prompt], self.device)[0] negative_prompt_embeds = negative_prompt_embeds.to(self.dtype).to(self.device) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds.new_zeros(text_len -negative_prompt_embeds.size(0), negative_prompt_embeds.size(1)) ]).unsqueeze(0) if self._interrupt: return None diff --git a/wan/modules/model.py b/wan/modules/model.py index ab352f6..b7a2670 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -1,4 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep ! +# I am sure you are a nice person and as you copy this code, you will give me officially proper credits: +# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter import math from einops import rearrange import torch @@ -176,6 +179,70 @@ class WanSelfAttention(nn.Module): self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + def text_cross_attention(self, xlist, context, return_q = False): + x = xlist[0] + xlist.clear() + b, n, d = x.size(0), self.num_heads, self.head_dim + nag_scale = offload.shared_state.get("_nag_scale",0) + # compute query, key, value + q = self.q(x) + del x + self.norm_q(q) + q= q.view(b, -1, n, d) + k = self.k(context) + self.norm_k(k) + k = k.view(context.shape[0], -1, n, d) + v = self.v(context).view(context.shape[0], -1, n, d) + + if nag_scale <= 1 or len(k)==1: + qvl_list=[q, k, v] + if not return_q: del q + del k, v + x = pay_attention(qvl_list, cross_attn= True) + x = x.flatten(2, 3) + else: + nag_tau = offload.shared_state["_nag_tau"] + nag_alpha = offload.shared_state["_nag_alpha"] + qvl_list=[q, k[:1], v[:1]] + x_pos = pay_attention(qvl_list, cross_attn= True) + qvl_list=[q, k[1:], v[1:]] + if not return_q: del q + del k, v + x_neg = pay_attention(qvl_list, cross_attn= True) + + x_pos = x_pos.flatten(2, 3) + x_neg = x_neg.flatten(2, 3) + # Behold DeepBeepMeep as the NAG Butcher !: reduce highly VRAM consumption while at the same time turn the source in gibberish + x_neg.mul_(1-nag_scale) + x_neg.add_(x_pos, alpha= nag_scale) + x_guidance = x_neg + del x_neg + norm_positive = torch.norm(x_pos, p=1, dim=-1, keepdim=True) + norm_guidance = torch.norm(x_guidance, p=1, dim=-1, keepdim=True) + scale = norm_guidance / norm_positive + scale = torch.nan_to_num(scale, 10) + factor = 1 / (norm_guidance + 1e-7) * norm_positive * nag_tau + x_guidance = torch.where(scale > nag_tau, x_guidance * factor, x_guidance ) + del norm_positive, norm_guidance + x_pos.mul_(1 - nag_alpha) + x_guidance.mul_(nag_alpha) + x_guidance.add_(x_pos) + x = x_guidance + + # x_guidance = x_pos * nag_scale - x_neg * (nag_scale - 1) + # norm_positive = torch.norm(x_pos, p=1, dim=-1, keepdim=True).expand(*x_pos.shape) + # norm_guidance = torch.norm(x_guidance, p=1, dim=-1, keepdim=True).expand(*x_guidance.shape) + + # scale = norm_guidance / norm_positive + # scale = torch.nan_to_num(scale, 10) + # x_guidance[scale > nag_tau] = x_guidance[scale > nag_tau] / (norm_guidance[scale > nag_tau] + 1e-7) * norm_positive[scale > nag_tau] * nag_tau + + # x = x_guidance * nag_alpha + x_pos * (1 - nag_alpha) + if return_q: + return x, q + else: + return x, None + def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0): r""" Args: @@ -246,28 +313,7 @@ class WanT2VCrossAttention(WanSelfAttention): x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - x = xlist[0] - xlist.clear() - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x) - del x - self.norm_q(q) - q= q.view(b, -1, n, d) - k = self.k(context) - self.norm_k(k) - k = k.view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - - # compute attention - v = v.contiguous().clone() - qvl_list=[q, k, v] - del q, k, v - x = pay_attention(qvl_list, cross_attn= True) - - # output - x = x.flatten(2) + x, _ = self.text_cross_attention( xlist, context) x = self.o(x) return x @@ -295,30 +341,14 @@ class WanI2VCrossAttention(WanSelfAttention): context(Tensor): Shape [B, L2, C] """ - ##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep ! - # I am sure you are a nice person and as you copy this code, you will give me officially proper credits: - # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter - - x = xlist[0] - xlist.clear() context_img = context[:, :257] context = context[:, 257:] - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x) - del x - self.norm_q(q) - q= q.view(b, -1, n, d) - k = self.k(context) - self.norm_k(k) - k = k.view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - - qkv_list = [q, k, v] - del k,v - x = pay_attention(qkv_list) + + x, q = self.text_cross_attention( xlist, context, return_q = True) + if len(q) != len(context_img): + context_img = context_img[:len(q)] + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim if audio_scale != None: audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens) @@ -329,12 +359,9 @@ class WanI2VCrossAttention(WanSelfAttention): qkv_list = [q, k_img, v_img] del q, k_img, v_img img_x = pay_attention(qkv_list) - # compute attention - + img_x = img_x.flatten(2) # output - x = x.flatten(2) - img_x = img_x.flatten(2) x += img_x del img_x if audio_scale != None: @@ -1187,11 +1214,18 @@ class WanModel(ModelMixin, ConfigMixin): e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) # context - context = [self.text_embedding( torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))] ).unsqueeze(0) ) for u in context ] + context = [self.text_embedding( u ) for u in context ] if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ] + context_list = [] + for one_context in context: + if len(one_context) != len(context_clip): + context_list.append( torch.cat( [context_clip.repeat(len(one_context), 1, 1), one_context ], dim=1 )) + else: + context_list.append( torch.cat( [context_clip, one_context ], dim=1 )) + else: + context_list = context if multitalk_audio != None: multitalk_audio_list = [] @@ -1208,7 +1242,6 @@ class WanModel(ModelMixin, ConfigMixin): else: multitalk_masks_list = [None] * len(x_list) - context_list = context if audio_scale != None: audio_scale_list = audio_scale else: diff --git a/wgp.py b/wgp.py index fb03263..ee8b669 100644 --- a/wgp.py +++ b/wgp.py @@ -50,8 +50,8 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.1" -WanGP_version = "6.6" -settings_version = 2.21 +WanGP_version = "6.7" +settings_version = 2.22 max_source_video_frames = 1000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1675,7 +1675,7 @@ model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "vace_1.3B", "phantom_1.3 "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", - "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B","recam_1.3B": "recammaster_1.3B", + "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B": "Vace_14B", "recam_1.3B": "recammaster_1.3B", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", @@ -1736,6 +1736,8 @@ def get_model_family(model_type): return "hunyuan" elif "ltxv" in model_type: return "ltxv" + elif "flux" in model_type: + return "flux" else: return "wan" @@ -1866,6 +1868,22 @@ def get_model_name(model_type, description_container = [""]): def get_model_record(model_name): return f"WanGP v{WanGP_version} by DeepBeepMeep - " + model_name +def get_finetune_URLs(model_type, stack= []): + finetune_def = finetunes.get(model_type, None) + if finetune_def != None: + URLs = finetune_def["URLs"] + if isinstance(URLs, str): + if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") + return get_finetune_URLs(URLs, stack = stack + [URLs] ) + else: + return URLs + else: + if model_type in model_types: + return model_type + else: + raise Exception(f"Unknown model type '{model_type}'") + + def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]): if is_module: choices = modules_files.get(model_type, None) @@ -1969,13 +1987,13 @@ def fix_settings(model_type, ui_defaults): remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 0) - if video_settings_version < 2.21: + if video_settings_version < 2.22: if "I" in video_prompt_type: if remove_background_images_ref == 2: video_prompt_type = video_prompt_type.replace("I", "KI") - if remove_background_images_ref != 0: - remove_background_images_ref = 1 - ui_defaults["remove_background_images_ref"] = remove_background_images_ref + if remove_background_images_ref != 0: + remove_background_images_ref = 1 + ui_defaults["remove_background_images_ref"] = remove_background_images_ref ui_defaults["video_prompt_type"] = video_prompt_type @@ -2331,12 +2349,12 @@ def download_models(model_filename, model_type): shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote" "" ], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], - ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ] + ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2374,18 +2392,22 @@ def download_models(model_filename, model_type): finetune_def = get_model_finetune_def(model_type) if finetune_def != None and not model_type in modules_files: if not os.path.isfile(model_filename ): - use_url = model_filename - for url in finetune_def["URLs"]: - if os.path.basename(model_filename) in url: - use_url = url - break - if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") - try: - download_file(use_url, model_filename) - except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") + URLs = get_finetune_URLs(model_type) + if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after + use_url = model_filename + for url in URLs: + if os.path.basename(model_filename) in url: + use_url = url + break + if not url.startswith("http"): + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") + try: + download_file(use_url, model_filename) + except Exception as e: + if os.path.isfile(model_filename): os.remove(model_filename) + raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") + model_filename = None + for url in finetune_def.get("preload_URLs", []): filename = "ckpts/" + url.split("/")[-1] if not os.path.isfile(filename ): @@ -2396,7 +2418,6 @@ def download_models(model_filename, model_type): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'") - model_filename = None if model_family == "wan": text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization) model_def = { @@ -3123,10 +3144,17 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_video_prompt_type = configs.get("video_prompt_type", "") video_image_prompt_type = configs.get("image_prompt_type", "") video_audio_prompt_type = configs.get("audio_prompt_type", "") - map_video_prompt = {"V" : "Control Video", "A" : "Mask Video", "I" : "Reference Images"} + def check(src, cond): + pos, neg = cond if isinstance(cond, tuple) else (cond, None) + if not all_letters(src, pos): return False + if neg is not None and any_letters(src, neg): return False + return True + map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"} map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} - video_other_prompts = [ v for s,v in map_image_prompt.items() if s in video_image_prompt_type] + [ v for s,v in map_video_prompt.items() if s in video_video_prompt_type] + [ v for s,v in map_audio_prompt.items() if s in video_audio_prompt_type] + video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \ + + [ v for s,v in map_video_prompt.items() if check(video_video_prompt_type,s)] \ + + [ v for s,v in map_audio_prompt.items() if all_letters(video_audio_prompt_type,s)] video_model_type = configs.get("model_type", "t2v") video_other_prompts = ", ".join(video_other_prompts) video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})" @@ -3138,24 +3166,31 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_length_summary += " (" if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, " video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)" - video_guidance_scale = configs.get("video_guidance_scale", 1) + video_guidance_scale = configs.get("guidance_scale", 1) + video_NAG_scale = configs.get("NAG_scale", 1) video_embedded_guidance_scale = configs.get("video_embedded_guidance_scale ", 1) if get_model_family(video_model_type) == "hunyuan": video_guidance_scale = video_embedded_guidance_scale video_guidance_label = "Embedded Guidance Scale" else: - video_guidance_label = "Guidance Scale" + video_guidance_label = "Guidance" video_flow_shift = configs.get("flow_shift", 1) video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" - if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#"): + if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \ + and (any_letters(video_video_prompt_type, "VFK") or any_letters(video_image_prompt_type, "VL")) : video_video_guide_outpainting = video_video_guide_outpainting.split(" ") video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%" video_num_inference_steps = configs.get("num_inference_steps", 0) video_creation_date = str(get_file_creation_date(file_name)) if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] video_generation_time = str(configs.get("generation_time", "0")) + "s" - video_activated_loras = "
".join(configs.get("activated_loras", [])) + video_activated_loras = configs.get("activated_loras", []) + video_loras_multipliers = configs.get("loras_multipliers", "") + video_loras_multipliers = preparse_loras_multipliers(video_loras_multipliers) + video_loras_multipliers += [""] * len(video_activated_loras) + video_activated_loras = [ f"{lora}x{multiplier if len(multiplier)>0 else '1'}" for lora, multiplier in zip(video_activated_loras, video_loras_multipliers) ] + video_activated_loras_str = "" + "".join(video_activated_loras) + "
" if len(video_activated_loras) > 0 else "" values += misc_values + [video_prompt] labels += misc_labels + ["Text Prompt"] if len(video_other_prompts) >0 : @@ -3166,7 +3201,14 @@ def select_video(state, input_file_list, event_data: gr.EventData): labels += ["Outpainting"] values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps] labels += [ "Resolution", "Video Length", "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"] - + video_negative_prompt = configs.get("negative_prompt", "") + if len(video_negative_prompt) > 0: + values += [video_negative_prompt] + labels += ["Negative Prompt"] + video_NAG_scale = configs.get("NAG_scale", 1) + if video_NAG_scale > 1: + values += [video_NAG_scale] + labels += ["NAG Scale"] video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "") video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0) video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0) @@ -3180,8 +3222,8 @@ def select_video(state, input_file_list, event_data: gr.EventData): values += pp_values labels += pp_labels - if len(video_activated_loras) > 0: - values += [video_activated_loras] + if len(video_activated_loras_str) > 0: + values += [video_activated_loras_str] labels += ["Loras"] if nb_audio_tracks > 0: values +=[nb_audio_tracks] @@ -3649,6 +3691,13 @@ def get_available_filename(target_path, video_source, suffix = "", force_extensi return full_path counter += 1 +def preparse_loras_multipliers(loras_multipliers): + loras_multipliers = loras_multipliers.strip(" \r\n") + loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") + loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] + loras_multipliers = " ".join(loras_mult_choices_list) + return loras_multipliers.split(" ") + def set_seed(seed): import random seed = random.randint(0, 99999999) if seed == None or seed < 0 else seed @@ -3839,6 +3888,9 @@ def generate_video( MMAudio_prompt, MMAudio_neg_prompt, RIFLEx_setting, + NAG_scale, + NAG_tau, + NAG_alpha, slg_switch, slg_layers, slg_start_perc, @@ -3927,11 +3979,8 @@ def generate_video( loras_list_mult_choices_nums = [] loras_multipliers = loras_multipliers.strip(" \r\n") if len(loras_multipliers) > 0: - loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") - loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] - loras_multipliers = " ".join(loras_mult_choices_list) - list_mult_choices_str = loras_multipliers.split(" ") - for i, mult in enumerate(list_mult_choices_str): + list_mult_choices_list = preparse_loras_multipliers(loras_multipliers) + for i, mult in enumerate(list_mult_choices_list): mult = mult.strip() if "," in mult: multlist = mult.split(",") @@ -4446,6 +4495,9 @@ def generate_video( model_filename = model_filename, model_type = base_model_type, loras_slists = loras_slists, + NAG_scale = NAG_scale, + NAG_tau = NAG_tau, + NAG_alpha = NAG_alpha, offloadobj = offloadobj, ) except Exception as e: @@ -5820,6 +5872,9 @@ def save_inputs( MMAudio_prompt, MMAudio_neg_prompt, RIFLEx_setting, + NAG_scale, + NAG_tau, + NAG_alpha, slg_switch, slg_layers, slg_start_perc, @@ -5975,6 +6030,18 @@ def unload_model_if_needed(state): gc.collect() reload_needed= True +def all_letters(source_str, letters): + for letter in letters: + if not letter in source_str: + return False + return True + +def any_letters(source_str, letters): + for letter in letters: + if letter in source_str: + return True + return False + def filter_letters(source_str, letters): ret = "" for letter in letters: @@ -6301,10 +6368,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_start = gr.Gallery( + image_start = gr.Gallery(preview= True, label="Images as starting points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery( + image_end = gr.Gallery(preview= True, label="Images as ending points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) @@ -6355,11 +6422,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - image_start = gr.Gallery( + image_start = gr.Gallery(preview= True, label="Images as starting points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery( + image_end = gr.Gallery(preview= True, label="Images as ending points for new videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) else: @@ -6495,7 +6562,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images", + image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images", type ="pil", show_label= True, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None), @@ -6609,7 +6676,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Column(): seed = gr.Slider(-1, 999999999, value=ui_defaults.get("seed",-1), step=1, label="Seed (-1 for random)") with gr.Row(visible = not ltxv): - guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=not (hunyuan_t2v or hunyuan_i2v)) + guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v)) audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=fantasy or multitalk) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v)) flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale") @@ -6626,10 +6693,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = vace): control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) + negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", "") ) + with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col: + gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") + with gr.Row(): + NAG_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("NAG_scale",1), step=0.1, label="NAG Scale", visible = True) + NAG_tau = gr.Slider(1.0, 5.0, value=ui_defaults.get("NAG_tau",3.5), step=0.1, label="NAG Tau", visible = True) + NAG_alpha = gr.Slider(1.0, 2.0, value=ui_defaults.get("NAG_alpha",.5), step=0.1, label="NAG Alpha", visible = True) with gr.Row(): - negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") ) - with gr.Row(): - repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt") + repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Num. of Generated Videos per Prompt") multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0), choices=[ ("Generate every combination of images and texts", 0), @@ -6647,7 +6719,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non multiselect= True, label="Activated Loras" ) - loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, lines that start with # are ignored", value=launch_multis_str) + loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str) with gr.Tab("Steps Skipping", visible = not ltxv) as speed_tab: with gr.Column(): gr.Markdown("Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.") @@ -6894,7 +6966,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gen_status = gr.Text(interactive= False, label = "Status") status_trigger = gr.Text(interactive= False, visible=False) default_files = [] - output = gr.Gallery(value =default_files, label="Generated videos", show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) + output = gr.Gallery(value =default_files, label="Generated videos", preview= True, show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) output_trigger = gr.Text(interactive= False, visible=False) refresh_form_trigger = gr.Text(interactive= False, visible=False) fill_wizard_prompt_trigger = gr.Text(interactive= False, visible=False) @@ -6973,7 +7045,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, - video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row] # presets_column, + video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, + NAG_col] # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs