Added Hunyuan Custom Audio and Edit support

DeepBeepMeep 2025-06-11 01:51:25 +02:00
parent 026c2b0cbb
commit 43aa414eaf
12 changed files with 395 additions and 113 deletions

View File

@ -20,11 +20,16 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### May 28 2025: WanGP v5.41 ### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: similar to Hunyuan Video Avatar, except there is no lower limit on the number of frames, and since the reference image is not a start image the person can be shown in a different context\
*Hunyuan Video Custom Edit*: a Hunyuan Video ControlNet; use it to do inpainting, for instance to replace a person in a video while keeping their pose. Similar to Vace, but less restricted than the Wan models in terms of content...
### June 6 2025: WanGP v5.41
👋 Bonus release: Support for the **AccVideo** Lora to speed up Wan video generations by x2. Check the Loras documentation for the AccVideo usage instructions.\
You will need to do a *pip install -r requirements.txt*
### May 28 2025: WanGP v5.4 ### June 6 2025: WanGP v5.4
👋 World Exclusive: **Hunyuan Video Avatar** support! You won't need 80 GB of VRAM, nor even 32 GB: just 10 GB of VRAM is sufficient to generate up to 15s of high-quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.\
Here is a link to the original repo, where you will find some very interesting documentation and examples: https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
Also many thanks to Reevoy24 for repackaging and completing the documentation

View File

@ -1,10 +1,14 @@
# Changelog
## 🔥 Latest News
### May 28 2025: WanGP v5.41 ### June 11 2025: WanGP v5.5
👋 *Hunyuan Video Custom Audio*: similar to Hunyuan Video Avatar, except there is no lower limit on the number of frames, and since the reference image is not a start image the person can be shown in a different context\
*Hunyuan Video Custom Edit*: a Hunyuan Video ControlNet; use it to do inpainting, for instance to replace a person in a video while keeping their pose. Similar to Vace, but less restricted than the Wan models in terms of content...
### June 6 2025: WanGP v5.41
👋 Bonus release: Support for the **AccVideo** Lora to speed up Wan video generations by x2. Check the Loras documentation for the AccVideo usage instructions.
### May 28 2025: WanGP v5.4 ### June 6 2025: WanGP v5.4
👋 World Exclusive: Hunyuan Video Avatar support! You won't need 80 GB of VRAM, nor even 32 GB: just 10 GB of VRAM is sufficient to generate up to 15s of high-quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.
### May 26, 2025: WanGP v5.3

View File

@ -54,6 +54,10 @@ def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
elif fps == 12.5:
start_ts = [0]
step_ts = [2]
else:
start_ts = [0]
step_ts = [1]
num_frames = min(num_frames, 400)
audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2)
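As an aside (not part of the commit): the new `else` branch simply supplies a default sampling stride so `encode_audio` also works for frame rates other than the 25 / 12.5 fps cases handled explicitly, which would presumably have left `start_ts` / `step_ts` undefined before. The Whisper stacking above produces the tensor that those indices later slice; a hypothetical shape walk-through, assuming the `ckpts/whisper-tiny/` checkpoint used elsewhere in this commit (4 encoder layers, d_model 384, 80 mel bins, 3000 mel frames = 30 s):

```python
import torch
from transformers import WhisperModel

# Hypothetical walk-through; the checkpoint path is the one used by this commit.
wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/")
mel = torch.zeros(1, 80, 3000)                      # dummy log-mel features
hidden = wav2vec.encoder(mel, output_hidden_states=True).hidden_states
print(len(hidden), hidden[0].shape)                 # 5 tensors (conv embedding + 4 layers), each [1, 1500, 384]
stacked = torch.stack(hidden, dim=2)
print(stacked.shape)                                # torch.Size([1, 1500, 5, 384])
# encode_audio then slices this tensor per video frame, which is where the
# fps-dependent start_ts / step_ts above come in.
```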

View File

@ -774,7 +774,10 @@ class HunyuanVideoPipeline(DiffusionPipeline):
# uncond_ref_latents: Optional[torch.Tensor] = None, # uncond_ref_latents: Optional[torch.Tensor] = None,
pixel_value_llava: Optional[torch.Tensor] = None, pixel_value_llava: Optional[torch.Tensor] = None,
uncond_pixel_value_llava: Optional[torch.Tensor] = None, uncond_pixel_value_llava: Optional[torch.Tensor] = None,
bg_latents: Optional[torch.Tensor] = None,
audio_prompts: Optional[torch.Tensor] = None,
ip_cfg_scale: float = 0.0, ip_cfg_scale: float = 0.0,
audio_strength: float = 1.0,
use_deepcache: int = 1, use_deepcache: int = 1,
num_videos_per_prompt: Optional[int] = 1, num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
@ -922,6 +925,9 @@ class HunyuanVideoPipeline(DiffusionPipeline):
# "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", # "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
# ) # )
if self._interrupt:
return [None]
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@ -968,7 +974,6 @@ class HunyuanVideoPipeline(DiffusionPipeline):
self._guidance_rescale = guidance_rescale self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
@ -1057,6 +1062,12 @@ class HunyuanVideoPipeline(DiffusionPipeline):
prompt_mask[0] = torch.cat([torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask), prompt_mask[0] = torch.cat([torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask),
torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1) torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1)
if bg_latents is not None:
bg_latents = torch.cat([bg_latents, bg_latents], dim=0)
if audio_prompts is not None:
audio_prompts = torch.cat([torch.zeros_like(audio_prompts), audio_prompts], dim=0)
if ip_cfg_scale>0: if ip_cfg_scale>0:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]]) prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]])
prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]]) prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]])
@ -1263,6 +1274,9 @@ class HunyuanVideoPipeline(DiffusionPipeline):
pipeline=self, pipeline=self,
x_id=j, x_id=j,
step_no=i, step_no=i,
bg_latents=bg_latents[j].unsqueeze(0) if bg_latents!=None else None,
audio_prompts=audio_prompts[j].unsqueeze(0) if audio_prompts!=None else None,
audio_strength=audio_strength,
callback = callback, callback = callback,
) )
if self._interrupt: if self._interrupt:
@ -1292,6 +1306,9 @@ class HunyuanVideoPipeline(DiffusionPipeline):
guidance=guidance_expand, guidance=guidance_expand,
pipeline=self, pipeline=self,
step_no=i, step_no=i,
bg_latents=bg_latents,
audio_prompts=audio_prompts,
audio_strength=audio_strength,
callback = callback, callback = callback,
) )
if self._interrupt:
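As an aside on the two `torch.cat` calls added in this file: they implement the usual classifier-free-guidance batching for the new conditioning inputs, repeating the background latents for both branches and zeroing the audio embedding for the unconditional branch. A minimal sketch (hypothetical helper name; the `[uncond, cond]` batch layout is assumed from the surrounding pipeline code):

```python
import torch

def expand_for_cfg(bg_latents, audio_prompts):
    """Sketch of the CFG expansion used above (hypothetical helper)."""
    if bg_latents is not None:
        # the background/mask latents condition both branches identically
        bg_latents = torch.cat([bg_latents, bg_latents], dim=0)
    if audio_prompts is not None:
        # silence drives the unconditional branch so guidance can amplify the audio signal
        audio_prompts = torch.cat([torch.zeros_like(audio_prompts), audio_prompts], dim=0)
    return bg_latents, audio_prompts
```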

View File

@ -908,6 +908,10 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
if self._interrupt:
return [None]
callback = kwargs.pop("callback", None) callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None) callback_steps = kwargs.pop("callback_steps", None)
if callback_steps is not None: if callback_steps is not None:
@ -956,7 +960,6 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
self._guidance_rescale = guidance_rescale self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):

View File

@ -235,13 +235,13 @@ def patched_llava_forward(
image_hidden_states=image_features if pixel_values is not None else None, image_hidden_states=image_features if pixel_values is not None else None,
) )
def adapt_avatar_model(model): def adapt_model(model, audio_block_name):
modules_dict= { k: m for k, m in model.named_modules()} modules_dict= { k: m for k, m in model.named_modules()}
for model_layer, avatar_layer in model.double_stream_map.items(): for model_layer, avatar_layer in model.double_stream_map.items():
module = modules_dict[f"audio_adapter_blocks.{avatar_layer}"] module = modules_dict[f"{audio_block_name}.{avatar_layer}"]
target = modules_dict[f"double_blocks.{model_layer}"] target = modules_dict[f"double_blocks.{model_layer}"]
setattr(target, "audio_adapter", module ) setattr(target, "audio_adapter", module )
delattr(model, "audio_adapter_blocks") delattr(model, audio_block_name)
class DataPreprocess(object): class DataPreprocess(object):
def __init__(self): def __init__(self):
@ -329,17 +329,25 @@ class Inference(object):
precision = "bf16" precision = "bf16"
vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16" vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
embedded_cfg_scale = 6 embedded_cfg_scale = 6
filepath = model_filepath[0]
i2v_condition_type = None i2v_condition_type = None
i2v_mode = "i2v" in model_filepath[0] i2v_mode = "i2v" in filepath
custom = False custom = False
custom_audio = False
avatar = False avatar = False
if i2v_mode: if i2v_mode:
model_id = "HYVideo-T/2" model_id = "HYVideo-T/2"
i2v_condition_type = "token_replace" i2v_condition_type = "token_replace"
elif "custom" in model_filepath[0]: elif "custom" in filepath:
if "audio" in filepath:
model_id = "HYVideo-T/2-custom-audio"
custom_audio = True
elif "edit" in filepath:
model_id = "HYVideo-T/2-custom-edit"
else:
model_id = "HYVideo-T/2-custom" model_id = "HYVideo-T/2-custom"
custom = True custom = True
elif "avatar" in model_filepath[0]: elif "avatar" in filepath :
model_id = "HYVideo-T/2-avatar" model_id = "HYVideo-T/2-avatar"
text_len = 256 text_len = 256
avatar = True avatar = True
@ -376,11 +384,11 @@ class Inference(object):
from mmgp import offload from mmgp import offload
# model = Inference.load_state_dict(args, model, model_filepath) # model = Inference.load_state_dict(args, model, model_filepath)
# model_filepath ="c:/temp/avatar/mp_rank_00_model_states.pt" # model_filepath ="c:/temp/hc/mp_rank_00_model_states_video.pt"
offload.load_model_data(model, model_filepath, pinToMemory = pinToMemory, partialPinning = partialPinning) offload.load_model_data(model, model_filepath, pinToMemory = pinToMemory, partialPinning = partialPinning)
pass pass
# offload.save_model(model, "hunyuan_video_avatar_720_bf16.safetensors") # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
# offload.save_model(model, "hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", do_quantize= True) # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
model.mixed_precision = mixed_precision_transformer model.mixed_precision = mixed_precision_transformer
@ -472,15 +480,17 @@ class Inference(object):
wav2vec = None wav2vec = None
align_instance = None align_instance = None
if avatar: if avatar or custom_audio:
feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/") feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32) wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
wav2vec._model_dtype = torch.float32 wav2vec._model_dtype = torch.float32
wav2vec.requires_grad_(False) wav2vec.requires_grad_(False)
if avatar:
align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt") align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
align_instance.facedet.model.to("cpu") align_instance.facedet.model.to("cpu")
adapt_model(model, "audio_adapter_blocks")
adapt_avatar_model(model) elif custom_audio:
adapt_model(model, "audio_models")
return cls( return cls(
i2v=i2v_mode, i2v=i2v_mode,
@ -600,7 +610,7 @@ class HunyuanVideoSampler(Inference):
return pipeline return pipeline
def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}): def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}, enable_riflex = False):
target_ndim = 3 target_ndim = 3
ndim = 5 - 2 ndim = 5 - 2
latents_size = [(video_length-1)//4+1 , height//8, width//8] latents_size = [(video_length-1)//4+1 , height//8, width//8]
@ -628,7 +638,10 @@ class HunyuanVideoSampler(Inference):
theta=256, theta=256,
use_real=True, use_real=True,
theta_rescale_factor=1, theta_rescale_factor=1,
concat_dict=concat_dict) concat_dict=concat_dict,
L_test = (video_length - 1) // 4 + 1,
enable_riflex = enable_riflex
)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
def get_rotary_pos_embed(self, video_length, height, width, enable_riflex = False): def get_rotary_pos_embed(self, video_length, height, width, enable_riflex = False):
@ -687,6 +700,9 @@ class HunyuanVideoSampler(Inference):
input_prompt, input_prompt,
input_ref_images = None, input_ref_images = None,
audio_guide = None, audio_guide = None,
input_frames = None,
input_masks = None,
input_video = None,
fps = 24, fps = 24,
height=192, height=192,
width=336, width=336,
@ -701,13 +717,14 @@ class HunyuanVideoSampler(Inference):
num_videos_per_prompt=1, num_videos_per_prompt=1,
i2v_resolution="720p", i2v_resolution="720p",
image_start=None, image_start=None,
enable_riflex = False, enable_RIFLEx = False,
i2v_condition_type: str = "token_replace", i2v_condition_type: str = "token_replace",
i2v_stability=True, i2v_stability=True,
VAE_tile_size = None, VAE_tile_size = None,
joint_pass = False, joint_pass = False,
cfg_star_switch = False, cfg_star_switch = False,
fit_into_canvas = True, fit_into_canvas = True,
conditioning_latents_size = 0,
**kwargs, **kwargs,
): ):
@ -777,6 +794,7 @@ class HunyuanVideoSampler(Inference):
target_height = align_to(height, 16) target_height = align_to(height, 16)
target_width = align_to(width, 16) target_width = align_to(width, 16)
target_frame_num = frame_num target_frame_num = frame_num
audio_strength = 1
if input_ref_images != None: if input_ref_images != None:
# ip_cfg_scale = 3.0 # ip_cfg_scale = 3.0
@ -862,7 +880,7 @@ class HunyuanVideoSampler(Inference):
# ======================================================================== # ========================================================================
if input_ref_images == None: if input_ref_images == None:
freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_riflex) freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
else: else:
if self.avatar: if self.avatar:
w, h = input_ref_images.size w, h = input_ref_images.size
@ -873,8 +891,13 @@ class HunyuanVideoSampler(Inference):
concat_dict = {'mode': 'timecat', 'bias': -1} concat_dict = {'mode': 'timecat', 'bias': -1}
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
else: else:
if input_frames != None:
target_height, target_width = input_frames.shape[-3:-1]
elif input_video != None:
target_height, target_width = input_video.shape[-2:]
concat_dict = {'mode': 'timecat-w', 'bias': -1} concat_dict = {'mode': 'timecat-w', 'bias': -1}
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict) freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict, enable_RIFLEx)
n_tokens = freqs_cos.shape[0] n_tokens = freqs_cos.shape[0]
@ -892,7 +915,35 @@ class HunyuanVideoSampler(Inference):
ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None
if audio_guide != None:
bg_latents = None
if input_video != None:
pixel_value_bg = input_video.unsqueeze(0)
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
if input_frames != None:
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
if input_video != None:
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
else:
pixel_value_bg = pixel_value_video_bg
pixel_value_mask = pixel_value_video_mask
pixel_value_video_mask, pixel_value_video_bg = None, None
if input_video != None or input_frames != None:
if pixel_value_bg.shape[2] < frame_num:
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
bg_latents.mul_(self.vae.config.scaling_factor)
if self.avatar:
if n_prompt == None or len(n_prompt) == 0: if n_prompt == None or len(n_prompt) == 0:
n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes" n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
@ -930,28 +981,33 @@ class HunyuanVideoSampler(Inference):
# from wan.utils.utils import cache_video # from wan.utils.utils import cache_video
# cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1)) # cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))
motion_pose = np.array([25] * 4)
motion_exp = np.array([30] * 4)
motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)
face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2), face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
(ref_latents.shape[-2], (ref_latents.shape[-2],
ref_latents.shape[-1]), ref_latents.shape[-1]),
mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype) mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
if audio_guide != None:
audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration = frame_num/fps ) audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration = frame_num/fps )
audio_prompts = audio_input[0] audio_prompts = audio_input[0]
weight_dtype = audio_prompts.dtype weight_dtype = audio_prompts.dtype
if self.custom:
motion_pose = np.array([25] * 4) audio_len = min(audio_len, frame_num)
motion_exp = np.array([30] * 4) audio_input = audio_input[:, :audio_len]
motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)
audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len) audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
audio_prompts = audio_prompts.to(self.model.dtype) audio_prompts = audio_prompts.to(self.model.dtype)
if audio_prompts.shape[1] <= 129: segment_size = 129 if self.avatar else frame_num
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1) if audio_prompts.shape[1] <= segment_size:
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,segment_size-audio_prompts.shape[1], 1, 1, 1)], dim=1)
else: else:
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1) audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129]) uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
# target_frame_num = min(target_frame_num, audio_len)
samples = self.pipeline( samples = self.pipeline(
prompt=input_prompt, prompt=input_prompt,
height=target_height, height=target_height,
@ -976,6 +1032,9 @@ class HunyuanVideoSampler(Inference):
motion_pose=motion_pose, motion_pose=motion_pose,
fps= torch.from_numpy(np.array(fps)), fps= torch.from_numpy(np.array(fps)),
bg_latents = bg_latents,
audio_strength = audio_strength,
denoise_strength=denoise_strength, denoise_strength=denoise_strength,
ip_cfg_scale=ip_cfg_scale, ip_cfg_scale=ip_cfg_scale,
freqs_cis=(freqs_cos, freqs_sin),
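To make the new edit path easier to follow: the control video and the masked/pose frames are concatenated along time, padded to the requested length, encoded by the VAE, and the encoded mask is appended as extra latent channels before scaling. A condensed sketch of that step (hypothetical helper name, restating the logic shown in the hunk; `vae` is assumed to expose `.encode(...).latent_dist.sample()` and `.config.scaling_factor` as above):

```python
import torch

def build_bg_latents(vae, pixel_value_bg, pixel_value_mask, frame_num):
    # pad both streams with "empty" frames (-1 = black video in [-1, 1], 255 = fully masked)
    if pixel_value_bg.shape[2] < frame_num:
        pad = list(pixel_value_bg.shape[:2]) + [frame_num - pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
        pixel_value_bg = torch.cat([pixel_value_bg, torch.full(pad, -1, dtype=pixel_value_bg.dtype, device=pixel_value_bg.device)], dim=2)
        pixel_value_mask = torch.cat([pixel_value_mask, torch.full(pad, 255, dtype=pixel_value_mask.dtype, device=pixel_value_mask.device)], dim=2)
    bg_latents = vae.encode(pixel_value_bg).latent_dist.sample()
    pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)          # [0, 255] -> [-1, 1]
    mask_latents = vae.encode(pixel_value_mask).latent_dist.sample()
    # background and mask latents are stacked along the channel axis, then scaled
    bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
    bg_latents.mul_(vae.config.scaling_factor)
    return bg_latents
```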

View File

@ -591,7 +591,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
attention_mode: Optional[str] = "sdpa", attention_mode: Optional[str] = "sdpa",
video_condition: bool = False,
audio_condition: bool = False,
avatar = False, avatar = False,
custom = False,
): ):
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
@ -606,6 +609,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
self.rope_dim_list = rope_dim_list self.rope_dim_list = rope_dim_list
self.i2v_condition_type = i2v_condition_type self.i2v_condition_type = i2v_condition_type
self.attention_mode = attention_mode self.attention_mode = attention_mode
self.video_condition = video_condition
self.audio_condition = audio_condition
self.avatar = avatar
self.custom = custom
# Text projection. Default to linear projection. # Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831 # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
@ -710,8 +717,15 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
get_activation_layer("silu"), get_activation_layer("silu"),
**factory_kwargs, **factory_kwargs,
) )
avatar_audio = avatar
if avatar_audio: if self.video_condition:
self.bg_in = PatchEmbed(
self.patch_size, self.in_channels * 2, self.hidden_size, **factory_kwargs
)
self.bg_proj = nn.Linear(self.hidden_size, self.hidden_size)
if audio_condition:
if avatar:
self.ref_in = PatchEmbed( self.ref_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
) )
@ -741,13 +755,18 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
# -------------------- audio_insert_model -------------------- # -------------------- audio_insert_model --------------------
self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
self.single_stream_list = [] audio_block_name = "audio_adapter_blocks"
self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)} elif custom:
self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)} self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
self.double_stream_list = [1, 3, 5, 7, 9, 11]
audio_block_name = "audio_models"
self.audio_adapter_blocks = nn.ModuleList([ self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
self.single_stream_list = []
self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)}
setattr(self, audio_block_name, nn.ModuleList([
PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list)) PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
]) ]))
@ -798,6 +817,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
motion_pose = None, motion_pose = None,
fps = None, fps = None,
face_mask = None, face_mask = None,
audio_strength = None,
bg_latents = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
img = x img = x
@ -854,12 +875,16 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
# Embed image and text. # Embed image and text.
img, shape_mask = self.img_in(img) img, shape_mask = self.img_in(img)
if audio_prompts != None: if self.avatar:
ref_latents_first = ref_latents[:, :, :1].clone() ref_latents_first = ref_latents[:, :, :1].clone()
ref_latents,_ = self.ref_in(ref_latents) ref_latents,_ = self.ref_in(ref_latents)
ref_latents_first,_ = self.img_in(ref_latents_first) ref_latents_first,_ = self.img_in(ref_latents_first)
elif ref_latents != None: elif self.custom:
if ref_latents != None:
ref_latents, _ = self.img_in(ref_latents) ref_latents, _ = self.img_in(ref_latents)
if bg_latents is not None and self.video_condition:
bg_latents, _ = self.bg_in(bg_latents)
img += self.bg_proj(bg_latents)
if self.text_projection == "linear": if self.text_projection == "linear":
txt = self.txt_in(txt) txt = self.txt_in(txt)
@ -870,7 +895,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
f"Unsupported text_projection: {self.text_projection}" f"Unsupported text_projection: {self.text_projection}"
) )
if audio_prompts != None: if self.avatar:
img += self.before_proj(ref_latents) img += self.before_proj(ref_latents)
ref_length = ref_latents_first.shape[-2] # [b s c] ref_length = ref_latents_first.shape[-2] # [b s c]
img = torch.cat([ref_latents_first, img], dim=-2) # t c img = torch.cat([ref_latents_first, img], dim=-2) # t c
@ -956,7 +981,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
if audio_adapter != None: if audio_adapter != None:
real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072) real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072)
real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072) real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072)
if face_mask != None:
real_img *= face_mask[i:i+1] real_img *= face_mask[i:i+1]
if audio_strength != None and audio_strength != 1:
real_img *= audio_strength
img[i:i+1, ref_length:] += real_img img[i:i+1, ref_length:] += real_img
real_img = None real_img = None
@ -1095,6 +1123,27 @@ HUNYUAN_VIDEO_CONFIG = {
"hidden_size": 3072, "hidden_size": 3072,
"heads_num": 24, "heads_num": 24,
"mlp_width_ratio": 4, "mlp_width_ratio": 4,
'custom' : True
},
'HYVideo-T/2-custom-audio': { # 9.0B / 12.5B
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
'custom' : True,
'audio_condition' : True,
},
'HYVideo-T/2-custom-edit': { # 9.0B / 12.5B
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
'custom' : True,
'video_condition' : True,
}, },
'HYVideo-T/2-avatar': { # 9.0B / 12.5B 'HYVideo-T/2-avatar': { # 9.0B / 12.5B
'mm_double_blocks_depth': 20, 'mm_double_blocks_depth': 20,
@ -1104,6 +1153,7 @@ HUNYUAN_VIDEO_CONFIG = {
'heads_num': 24, 'heads_num': 24,
'mlp_width_ratio': 4, 'mlp_width_ratio': 4,
'avatar': True, 'avatar': True,
'audio_condition' : True,
}, },
}
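The new `video_condition` flag (set by the `HYVideo-T/2-custom-edit` config above) drives the `bg_in` / `bg_proj` pair added earlier in this file: the concatenated video+mask latents are patchified like the main latents and added to the image tokens. A standalone sketch, with `MiniPatchEmbed` standing in for the model's `PatchEmbed`; hidden_size 3072 comes from the configs above, while the latent channel count and patch size are illustrative assumptions:

```python
import torch
import torch.nn as nn

class MiniPatchEmbed(nn.Module):
    """Stand-in for the model's PatchEmbed: patchify and return (tokens, grid shape)."""
    def __init__(self, patch_size, in_chans, hidden_size):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                                 # x: [B, C, T, H, W]
        x = self.proj(x)                                  # [B, hidden, T/p, H/p, W/p]
        shape = x.shape[2:]
        return x.flatten(2).transpose(1, 2), shape        # tokens: [B, N, hidden]

patch_size, in_channels, hidden = (1, 2, 2), 16, 3072
bg_in = MiniPatchEmbed(patch_size, in_channels * 2, hidden)   # *2: video + mask latents
bg_proj = nn.Linear(hidden, hidden)

bg_latents = torch.randn(1, in_channels * 2, 9, 32, 32)       # [B, 2C, T, H, W]
img = torch.randn(1, 9 * 16 * 16, hidden)                      # patchified video tokens
bg_tokens, _ = bg_in(bg_latents)
img = img + bg_proj(bg_tokens)                                  # additive token-wise conditioning
```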

View File

@ -295,7 +295,10 @@ def apply_rotary_emb( qklist,
def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False, def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False,
theta_rescale_factor: Union[float, List[float]]=1.0, theta_rescale_factor: Union[float, List[float]]=1.0,
interpolation_factor: Union[float, List[float]]=1.0, interpolation_factor: Union[float, List[float]]=1.0,
concat_dict={} concat_dict={},
k = 4,
L_test = 66,
enable_riflex = True
): ):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
@ -327,9 +330,17 @@ def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_r
# use 1/ndim of dimensions to encode grid_axis # use 1/ndim of dimensions to encode grid_axis
embs = [] embs = []
for i in range(len(rope_dim_list)): for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, # === RIFLEx modification start ===
theta_rescale_factor=theta_rescale_factor[i], # apply RIFLEx for time dimension
interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] if i == 0 and enable_riflex:
emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
# === RIFLEx modification end ===
else:
emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
# emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
# theta_rescale_factor=theta_rescale_factor[i],
# w interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
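For the time axis (`i == 0`) the call is now routed to `get_1d_rotary_pos_embed_riflex` with `k` and `L_test` (the latent frame count passed down from the sampler). That helper is not part of this hunk; the sketch below shows the RIFLEx idea it is assumed to implement, following the published RIFLEx recipe, purely for orientation (hypothetical function name, `use_real=True` layout assumed):

```python
import math
import torch

def rotary_freqs_riflex(dim: int, positions: torch.Tensor, theta: float = 256.0,
                        k: int = 4, L_test: int = 66):
    """Standard RoPE frequencies, except the k-th temporal frequency is lowered so a
    single period spans (slightly more than) the L_test latent frames, which suppresses
    repeated motion when sampling longer videos than the training length."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    freqs[k - 1] = 0.9 * 2 * math.pi / L_test        # RIFLEx: clamp the intrinsic frequency
    angles = torch.outer(positions.float(), freqs)   # [len(pos), dim // 2]
    cos = angles.cos().repeat_interleave(2, dim=1)   # [len(pos), dim]
    sin = angles.sin().repeat_interleave(2, dim=1)
    return cos, sin
```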

View File

@ -517,11 +517,15 @@ def export_image(image_refs, image_output):
image_refs.append( image_output) image_refs.append( image_output)
return image_refs return image_refs
def export_to_current_video_engine(foreground_video_output, alpha_video_output): def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
gr.Info("Masked Video Input and Full Mask transferred to Current Video Engine For Inpainting") gr.Info("Masked Video Input and Full Mask transferred to Current Video Engine For Inpainting")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
if "custom_edit" in model_type:
return gr.update(), alpha_video_output
else:
return foreground_video_output, alpha_video_output return foreground_video_output, alpha_video_output
def teleport_to_video_tab(): def teleport_to_video_tab():
return gr.Tabs(selected="video_gen") return gr.Tabs(selected="video_gen")
@ -675,7 +679,7 @@ def display(tabs, model_choice, vace_video_input, vace_video_mask, vace_image_re
export_to_vace_video_14B_btn.click( fn=teleport_to_vace_14B, inputs=[], outputs=[tabs, model_choice]).then( export_to_vace_video_14B_btn.click( fn=teleport_to_vace_14B, inputs=[], outputs=[tabs, model_choice]).then(
fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask]) fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [model_choice, foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger,
fn=teleport_to_video_tab, inputs= [], outputs= [tabs]) fn=teleport_to_video_tab, inputs= [], outputs= [tabs])

View File

@ -33,4 +33,5 @@ librosa
loguru loguru
sentencepiece sentencepiece
av av
opencv-python
# rembg==2.0.65 # rembg==2.0.65

View File

@ -46,7 +46,8 @@ def resample(video_fps, video_frames_count, max_target_frames_count, target_fps,
while True:
if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count :
break
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration ) diff = round( (target_time -cur_time) / video_frame_duration , 5)
add_frames_count = math.ceil( diff)
frame_no += add_frames_count
if frame_no >= video_frames_count:
break
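Why the added `round(..., 5)` matters: when `(target_time - cur_time) / video_frame_duration` is mathematically an integer, floating-point noise can leave it a hair above that integer, and `math.ceil` then advances one source frame too many. A tiny illustration of the failure mode (values are chosen for the example, not taken from `resample`):

```python
import math

# The ratio resample() computes is often an exact integer mathematically, but
# IEEE-754 noise can land it just above one, and math.ceil then overshoots.
ratio = (0.1 + 0.2) * 10           # mathematically 3, actually 3.0000000000000004
print(math.ceil(ratio))            # 4 -> would skip one source frame too many
print(math.ceil(round(ratio, 5)))  # 3 -> the behaviour the fix restores
```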

wgp.py
View File

@ -33,6 +33,7 @@ import tempfile
import atexit import atexit
import shutil import shutil
import glob import glob
import cv2
from transformers.utils import logging from transformers.utils import logging
logging.set_verbosity_error logging.set_verbosity_error
@ -43,7 +44,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10 PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.8" target_mmgp_version = "3.4.8"
WanGP_version = "5.41" WanGP_version = "5.5"
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
from importlib.metadata import version from importlib.metadata import version
@ -184,12 +185,18 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return return
if "diffusion_forcing" in model_filename or "ltxv" in model_filename or "Vace" in model_filename: if "diffusion_forcing" in model_filename or "ltxv" in model_filename or "Vace" in model_filename or "hunyuan_video_custom_edit" in model_filename:
video_length = inputs["video_length"] video_length = inputs["video_length"]
sliding_window_size = inputs["sliding_window_size"] sliding_window_size = inputs["sliding_window_size"]
if video_length > sliding_window_size: if video_length > sliding_window_size:
gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated") gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated")
if "hunyuan_video_custom_edit" in model_filename:
keep_frames_video_guide= inputs["keep_frames_video_guide"]
if len(keep_frames_video_guide) > 0:
gr.Info("Filtering Frames with this model is not supported")
return
if "phantom" in model_filename or "hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename: if "phantom" in model_filename or "hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename:
image_refs = inputs["image_refs"] image_refs = inputs["image_refs"]
audio_guide = inputs["audio_guide"] audio_guide = inputs["audio_guide"]
@ -1552,6 +1559,8 @@ ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13
hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16v2.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors", hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16v2.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors",
"ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors", "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors",
"ckpts/hunyuan_video_custom_audio_720_bf16.safetensors", "ckpts/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors",
"ckpts/hunyuan_video_custom_edit_720_bf16.safetensors", "ckpts/hunyuan_video_custom_edit_720_quanto_bf16_int8.safetensors",
"ckpts/hunyuan_video_avatar_720_bf16.safetensors", "ckpts/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", "ckpts/hunyuan_video_avatar_720_bf16.safetensors", "ckpts/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors",
] ]
@ -1563,13 +1572,14 @@ def get_dependent_models(model_filename, quantization, dtype_policy ):
return [get_model_filename("ltxv_13B", quantization, dtype_policy)] return [get_model_filename("ltxv_13B", quantization, dtype_policy)]
else: else:
return [] return []
model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_avatar"] model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B",
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "moviigen" :"moviigen", "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "moviigen" :"moviigen",
"phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "fantasy" : "fantasy", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "fantasy" : "fantasy", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled",
"hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom", "hunyuan_avatar" : "hunyuan_video_avatar" } "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit",
"hunyuan_avatar" : "hunyuan_video_avatar" }
def get_model_type(model_filename): def get_model_type(model_filename):
@ -1653,6 +1663,13 @@ def get_model_name(model_filename, description_container = [""]):
model_name = "Hunyuan Video image2video 720p 13B" model_name = "Hunyuan Video image2video 720p 13B"
description = "A good looking image 2 video model, but not so good in prompt adherence." description = "A good looking image 2 video model, but not so good in prompt adherence."
elif "hunyuan_video_custom" in model_filename: elif "hunyuan_video_custom" in model_filename:
if "audio" in model_filename:
model_name = "Hunyuan Video Custom Audio 720p 13B"
description = "The Hunyuan Video Custom Audio model can be used to generate scenes of a person speaking given a Reference Image and a Recorded Voice or Song. The reference image is not a start image and therefore one can represent the person in a different context.The video length can be anything up to 10s. It is also quite good to generate no sound Video based on a person."
elif "edit" in model_filename:
model_name = "Hunyuan Video Custom Edit 720p 13B"
description = "The Hunyuan Video Custom Edit model can be used to do Video inpainting on a person (add accessories or completely replace the person). You will need in any case to define a Video Mask which will indicate which area of the Video should be edited."
else:
model_name = "Hunyuan Video Custom 720p 13B" model_name = "Hunyuan Video Custom 720p 13B"
description = "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the momment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps." description = "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the momment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps."
elif "hunyuan_video_avatar" in model_filename: elif "hunyuan_video_avatar" in model_filename:
@ -1778,6 +1795,13 @@ def get_default_settings(filename):
"flow_shift": 13, "flow_shift": 13,
"resolution": "1280x720", "resolution": "1280x720",
}) })
elif get_model_type(filename) in ("hunyuan_custom_edit"):
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "MV",
"sliding_window_size": 129,
})
elif get_model_type(filename) in ("hunyuan_avatar"): elif get_model_type(filename) in ("hunyuan_avatar"):
ui_defaults.update({ ui_defaults.update({
"guidance_scale": 7.5, "guidance_scale": 7.5,
@ -2570,19 +2594,96 @@ def convert_image(image):
image = image.convert('RGB') image = image.convert('RGB')
return cast(Image, ImageOps.exif_transpose(image)) return cast(Image, ImageOps.exif_transpose(image))
def get_resampled_video(video_in, start_frame, max_frames, target_fps): def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
from wan.utils.utils import resample from wan.utils.utils import resample
import decord import decord
decord.bridge.set_bridge('torch') decord.bridge.set_bridge(bridge)
reader = decord.VideoReader(video_in) reader = decord.VideoReader(video_in)
fps = reader.get_avg_fps() fps = round(reader.get_avg_fps())
frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=target_fps, start_target_frame= start_frame) frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=target_fps, start_target_frame= start_frame)
frames_list = reader.get_batch(frame_nos) frames_list = reader.get_batch(frame_nos)
# print(f"frame nos: {frame_nos}")
return frames_list return frames_list
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, pose_enhance = True, to_bbox = False):
if not input_video_path or not input_mask_path:
return None, None
from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
cfg_dict = {
"DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
"POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
"RESIZE_SIZE": 1024
}
dwpose = PoseBodyFaceVideoAnnotator(cfg_dict)
video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
if len(video) == 0 or len(mask_video) == 0:
return None, None
frame_height, frame_width, _ = video[0].shape
if fit_canvas :
scale1 = min(height / frame_height, width / frame_width)
scale2 = min(height / frame_width, width / frame_height)
scale = max(scale1, scale2)
else:
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
height = (int(frame_height * scale) // block_size) * block_size
width = (int(frame_width * scale) // block_size) * block_size
num_frames = min(len(video), len(mask_video))
masked_frames = []
masks = []
for frame_idx in range(num_frames):
frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy()
mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy()
frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS)
mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS)
frame = np.array(frame)
mask = np.array(mask)
if len(mask.shape) == 3 and mask.shape[2] == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
if expand_scale != 0:
kernel_size = abs(expand_scale)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
mask = op_expand(mask, kernel, iterations=3)
_, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY)
if to_bbox and np.sum(mask == 255) > 0:
x0, y0, x1, y1 = mask_to_xyxy_box(mask)
mask = mask * 0
mask[y0:y1, x0:x1] = 255
inverse_mask = mask == 0
if pose_enhance:
pose_img = dwpose.forward([frame])[0]
masked_frame = np.where(inverse_mask[..., None], frame, pose_img)
else:
masked_frame = frame * (inverse_mask[..., None].astype(frame.dtype))
mask = torch.from_numpy(mask) # to be commented if save one video enabled
masked_frame = torch.from_numpy(masked_frame) # to be commented if save one video debug enabled
masks.append(mask)
masked_frames.append(masked_frame)
# from preprocessing.dwpose.pose import save_one_video
# save_one_video("masked_frames.mp4", masked_frames, fps=target_fps, quality=8, macro_block_size=None)
# save_one_video("masks.mp4", masks, fps=target_fps, quality=8, macro_block_size=None)
return torch.stack(masked_frames), torch.stack(masks)
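A hypothetical usage sketch for the new helper (file names, resolution and fps are placeholders): it returns the masked control frames, with DW-Pose skeletons drawn inside the mask when `pose_enhance` is set, together with the binary mask frames that the custom-edit sampler encodes alongside them.

```python
# Hypothetical call; paths and sizes are placeholders. The helper loads the
# DW-Pose models from ckpts/pose/ as configured above.
masked_frames, masks = preprocess_video_with_mask(
    "inputs/dance.mp4",          # control video
    "inputs/dance_mask.mp4",     # mask video, white = area to edit
    height=720, width=1280,
    max_frames=129, target_fps=24,
    expand_scale=2,              # dilate the mask slightly before thresholding
    pose_enhance=True,           # draw DW-Pose skeletons inside the masked area
)
# masked_frames: [T, H, W, 3] uint8-like tensor fed to the custom-edit sampler
# masks:         [T, H, W] tensor with values in {0, 255}
```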
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size = 16): def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size = 16):
frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps)
@ -2600,9 +2701,6 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
new_height = (int(frame_height * scale) // block_size) * block_size new_height = (int(frame_height * scale) // block_size) * block_size
new_width = (int(frame_width * scale) // block_size) * block_size new_width = (int(frame_width * scale) // block_size) * block_size
# if fit_canvas :
# new_height = height
# new_width = width
processed_frames_list = [] processed_frames_list = []
for frame in frames_list: for frame in frames_list:
@ -2857,12 +2955,14 @@ def generate_video(
hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom = "hunyuan_video_custom" in model_filename
hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
hunyuan_avatar = "hunyuan_video_avatar" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename
fantasy = "fantasy" in model_filename fantasy = "fantasy" in model_filename
if diffusion_forcing or hunyuan_t2v or hunyuan_i2v or hunyuan_custom: if hunyuan_avatar or hunyuan_custom_audio:
fps = 24
elif hunyuan_avatar:
fps = 25 fps = 25
elif diffusion_forcing or hunyuan_t2v or hunyuan_i2v or hunyuan_custom:
fps = 24
elif fantasy: elif fantasy:
fps = 23 fps = 23
elif ltxv: elif ltxv:
@ -2913,7 +3013,7 @@ def generate_video(
audio_proj_split = None audio_proj_split = None
audio_scale = None audio_scale = None
audio_context_lens = None audio_context_lens = None
if (fantasy or hunyuan_avatar) and audio_guide != None: if (fantasy or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None:
from fantasytalking.infer import parse_audio from fantasytalking.infer import parse_audio
import librosa import librosa
duration = librosa.get_duration(path=audio_guide) duration = librosa.get_duration(path=audio_guide)
@ -2922,6 +3022,12 @@ def generate_video(
audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= current_video_length, fps= fps, device= processing_device ) audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= current_video_length, fps= fps, device= processing_device )
audio_scale = 1.0 audio_scale = 1.0
if hunyuan_custom_edit and video_guide != None:
import cv2
cap = cv2.VideoCapture(video_guide)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
current_video_length = min(current_video_length, length)
import random import random
if seed == None or seed <0: if seed == None or seed <0:
seed = random.randint(0, 999999999) seed = random.randint(0, 999999999)
@ -2938,13 +3044,10 @@ def generate_video(
repeat_no = 0 repeat_no = 0
extra_generation = 0 extra_generation = 0
initial_total_windows = 0 initial_total_windows = 0
if diffusion_forcing or vace or ltxv:
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
else:
reuse_frames = 0
if (diffusion_forcing or ltxv) and source_video != None: if (diffusion_forcing or ltxv) and source_video != None:
current_video_length += sliding_window_overlap current_video_length += sliding_window_overlap
sliding_window = (vace or diffusion_forcing or ltxv) and current_video_length > sliding_window_size sliding_window = (vace or diffusion_forcing or ltxv or hunyuan_custom_edit) and current_video_length > sliding_window_size
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) if sliding_window else 0
discard_last_frames = sliding_window_discard_last_frames discard_last_frames = sliding_window_discard_last_frames
default_max_frames_to_generate = current_video_length default_max_frames_to_generate = current_video_length
@ -3094,6 +3197,10 @@ def generate_video(
pre_src_video = [pre_video_guide], pre_src_video = [pre_video_guide],
fit_into_canvas = fit_canvas fit_into_canvas = fit_canvas
) )
elif hunyuan_custom_edit:
progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
send_cmd("progress", progress_args)
src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps, pose_enhance = "P" in video_prompt_type)
if window_no == 1: if window_no == 1:
conditioning_latents_size = ( (prefix_video_frames_count-1) // latent_size) + 1 if prefix_video_frames_count > 0 else 0 conditioning_latents_size = ( (prefix_video_frames_count-1) // latent_size) + 1 if prefix_video_frames_count > 0 else 0
else: else:
@ -3124,7 +3231,7 @@ def generate_video(
input_frames = src_video, input_frames = src_video,
input_ref_images= src_ref_images, input_ref_images= src_ref_images,
input_masks = src_mask, input_masks = src_mask,
input_video= pre_video_guide if diffusion_forcing or ltxv else source_video, input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video,
target_camera= target_camera, target_camera= target_camera,
frame_num=(current_video_length // latent_size)* latent_size + 1, frame_num=(current_video_length // latent_size)* latent_size + 1,
height = height, height = height,
@ -3693,12 +3800,12 @@ def process_tasks(state):
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, window_no, total_windows): def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, window_no, total_windows):
if prompts_max == 1: if prompts_max == 1:
if repeat_max == 1: if repeat_max <= 1:
status = "" status = ""
else: else:
status = f"Sample {repeat_no}/{repeat_max}" status = f"Sample {repeat_no}/{repeat_max}"
else: else:
if repeat_max == 1: if repeat_max <= 1:
status = f"Prompt {prompt_no}/{prompts_max}" status = f"Prompt {prompt_no}/{prompts_max}"
else: else:
status = f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}" status = f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}"
@ -4111,7 +4218,7 @@ def prepare_inputs_dict(target, inputs ):
inputs.pop(k) inputs.pop(k)
if not "Vace" in model_filename and not "diffusion_forcing" in model_filename and not "ltxv" in model_filename: if not "Vace" in model_filename and not "diffusion_forcing" in model_filename and not "ltxv" in model_filename and not "hunyuan_custom_edit" in model_filename:
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"] unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
for k in unsaved_params: for k in unsaved_params:
inputs.pop(k) inputs.pop(k)
@ -4556,8 +4663,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_video_custom = "hunyuan_video_custom" in model_filename hunyuan_video_custom = "hunyuan_video_custom" in model_filename
hunyuan_video_custom_audio = hunyuan_video_custom and "audio" in model_filename
hunyuan_video_custom_edit = hunyuan_video_custom and "edit" in model_filename
hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename
sliding_window_enabled = vace or diffusion_forcing or ltxv sliding_window_enabled = vace or diffusion_forcing or ltxv or hunyuan_video_custom_edit
new_line_text = "each new line of prompt will be used for a window" if sliding_window_enabled else "each new line of prompt will generate a new video" new_line_text = "each new line of prompt will be used for a window" if sliding_window_enabled else "each new line of prompt will generate a new video"
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or ltxv or recammaster) as image_prompt_column: with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or ltxv or recammaster) as image_prompt_column:
@@ -4630,7 +4739,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         model_mode = gr.Dropdown(value=None, visible=False)
         keep_frames_video_source = gr.Text(visible=False)
-    with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar) as video_prompt_column:
+    with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit ) as video_prompt_column:
         video_prompt_type_value= ui_defaults.get("video_prompt_type","")
         video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
         with gr.Row():
@@ -4648,6 +4757,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 value=filter_letters(video_prompt_type_value, "ODPCMV"),
                 label="Video to Video", scale = 3, visible= True
             )
+        elif hunyuan_video_custom_edit:
+            video_prompt_type_video_guide = gr.Dropdown(
+                choices=[
+                    ("Inpaint Control Video in area defined by Mask", "MV"),
+                    ("Inpaint and Transfer Human Motion from the Control Video in area defined by Mask", "PMV"),
+                ],
+                value=filter_letters(video_prompt_type_value, "ODPCMV"),
+                label="Video to Video", scale = 3, visible= True
+            )
         else:
             video_prompt_type_video_guide = gr.Dropdown(visible= False)
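The new dropdown encodes the edit mode as letter codes ("MV" for mask-only inpainting, "PMV" for inpainting plus pose transfer) and derives its initial value through `filter_letters`, a helper defined elsewhere in this file and not shown in the diff. A hypothetical stand-in for the behaviour the dropdown values appear to rely on:

```python
def filter_letters(value: str, letters: str) -> str:
    # Hypothetical stand-in, not the repository's implementation: keep only the
    # characters of `value` that belong to the allowed set, preserving order.
    return "".join(ch for ch in value if ch in letters)

# With a stored video_prompt_type containing extra letters, only those the
# dropdown understands ("ODPCMV") survive, here selecting the "PMV" choice.
print(filter_letters("PMVX", "ODPCMV"))  # "PMV"
```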
@@ -4686,7 +4804,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
-            audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy or hunyuan_video_avatar )
+            audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy or hunyuan_video_avatar or hunyuan_video_custom_audio )
             advanced_prompt = advanced_ui
             prompt_vars=[]
@@ -4775,9 +4893,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
             elif fantasy:
                 video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
-            elif hunyuan_video_avatar:
+            elif hunyuan_video_avatar or hunyuan_video_custom_audio:
                 video_length = gr.Slider(5, 401, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (25 = 1s)", interactive= True)
-            elif hunyuan_t2v or hunyuan_i2v:
+            elif hunyuan_t2v or hunyuan_i2v or hunyuan_video_custom:
                 video_length = gr.Slider(5, 337, value=ui_defaults.get("video_length", 97), step=4, label="Number of frames (24 = 1s)", interactive= True)
             else:
                 video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
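The slider labels encode each family's frame rate, so the same frame count maps to different playback lengths per branch. A quick back-of-the-envelope helper using only the fps values stated in the labels above:

```python
def approx_duration_seconds(frame_count: int, frames_per_second: int) -> float:
    # Rough playback length implied by a slider value and the "(N = 1s)" label.
    return frame_count / frames_per_second

# Maximum of the Avatar / Custom Audio branch: 401 frames at 25 fps.
print(round(approx_duration_seconds(401, 25), 1))  # ~16.0 s
# Maximum of the Hunyuan t2v / i2v / Custom branch: 337 frames at 24 fps.
print(round(approx_duration_seconds(337, 24), 1))  # ~14.0 s
```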
@@ -4915,17 +5033,22 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 if diffusion_forcing:
                     sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                     sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
-                    sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
+                    sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
                 elif ltxv:
                     sliding_window_size = gr.Slider(41, 257, value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size")
                     sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
-                    sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
+                    sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
+                elif hunyuan_video_custom_edit:
+                    sliding_window_size = gr.Slider(5, 257, value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
+                    sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+                    sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
+                    sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
                 else:
                     sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                     sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
-                    sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
+                    sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
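For the new hunyuan_video_custom_edit branch (as for the others), the number of windows needed for a target length follows from the window size, the overlap and the discarded trailing frames. A back-of-the-envelope sketch, assuming each additional window contributes `size - overlap - discard` fresh frames; this is an editorial estimate, not code from the repository:

```python
import math

def estimate_window_count(total_frames: int, window_size: int,
                          overlap: int, discard_last: int) -> int:
    # The first window keeps `window_size - discard_last` frames; every further
    # window re-generates `overlap` frames and drops `discard_last` at its end.
    usable = window_size - discard_last
    if total_frames <= usable:
        return 1
    step = window_size - overlap - discard_last
    return 1 + math.ceil((total_frames - usable) / step)

# Defaults of the hunyuan_video_custom_edit branch above: size 129, overlap 5, discard 0.
print(estimate_window_count(253, 129, 5, 0))  # 2 windows for a 253-frame target
```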