added outpainting on injected frames and a shapes preprocessor

DeepBeepMeep 2025-06-20 23:57:07 +02:00
parent 8b146a8d7b
commit d0a32c67a0
5 changed files with 276 additions and 80 deletions

@@ -26,7 +26,8 @@ class DepthV2Annotator:
         self.model.load_state_dict(
             torch.load(
                 pretrained_model,
-                map_location=self.device
+                map_location=self.device,
+                weights_only=True
             )
         )
         self.model.eval()
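Note: the weights_only=True flag added above restricts torch.load to plain tensor/state-dict data instead of arbitrary pickled objects, which hardens checkpoint loading against untrusted files. A minimal sketch of the pattern (PyTorch 1.13+; the checkpoint path is illustrative):

import torch

# weights_only=True refuses to unpickle arbitrary Python objects, so a
# malicious checkpoint cannot execute code at load time.
state_dict = torch.load("ckpts/depth/depth_anything_v2_vitl.pth",
                        map_location="cpu", weights_only=True)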

preprocessing/scribble.py (new file)
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image

norm_layer = nn.InstanceNorm2d


def convert_to_torch(image):
    if isinstance(image, Image.Image):
        image = torch.from_numpy(np.array(image)).float()
    elif isinstance(image, torch.Tensor):
        image = image.clone()
    elif isinstance(image, np.ndarray):
        image = torch.from_numpy(image.copy()).float()
    else:
        raise TypeError(
            f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
    return image


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class ContourInference(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(ContourInference, self).__init__()

        # Initial convolution block
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True)
        ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        # Residual blocks
        model2 = []
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features,
                                   out_features,
                                   3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out


class ScribbleAnnotator:
    def __init__(self, cfg, device=None):
        input_nc = cfg.get('INPUT_NC', 3)
        output_nc = cfg.get('OUTPUT_NC', 1)
        n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
        sigmoid = cfg.get('SIGMOID', True)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = ContourInference(input_nc, output_nc, n_residual_blocks,
                                      sigmoid)
        self.model.load_state_dict(torch.load(pretrained_model, weights_only=True))
        self.model = self.model.eval().requires_grad_(False).to(self.device)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        is_batch = len(image.shape) != 3
        image = convert_to_torch(image)
        if len(image.shape) == 3:
            image = rearrange(image, 'h w c -> 1 c h w')
        image = image.float().div(255).to(self.device)
        contour_map = self.model(image)
        contour_map = (contour_map.squeeze(dim=1) * 255.0).clip(
            0, 255).cpu().numpy().astype(np.uint8)
        contour_map = contour_map[..., None].repeat(3, -1)
        if not is_batch:
            contour_map = contour_map.squeeze()
        return contour_map


class ScribbleVideoAnnotator(ScribbleAnnotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            anno_frame = super().forward(np.array(frame))
            ret_frames.append(anno_frame)
        return ret_frames
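A minimal usage sketch for the new annotator, assuming the checkpoint has been downloaded to the path wgp.py wires up below (frame size and count are illustrative):

import numpy as np
from preprocessing.scribble import ScribbleVideoAnnotator

annotator = ScribbleVideoAnnotator({"PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"})
frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(4)]  # dummy RGB frames
contour_maps = annotator.forward(frames)  # list of H x W x 3 uint8 contour maps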

@@ -27,6 +27,7 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from wan.modules.posemb_layers import get_rotary_pos_embed
 from .utils.vace_preprocessor import VaceVideoProcessor
 from wan.utils.basic_flowmatch import FlowMatchScheduler
+from wan.utils.utils import get_outpainting_frame_location

 def optimized_scale(positive_flat, negative_flat):
@@ -188,38 +189,52 @@ class WanT2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
+    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
+        from wan.utils.utils import save_image
         ref_width, ref_height = ref_img.size
-        if (ref_height, ref_width) == image_size:
+        if (ref_height, ref_width) == image_size and outpainting_dims == None:
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
         else:
-            canvas_height, canvas_width = image_size
+            if outpainting_dims != None:
+                final_height, final_width = image_size
+                canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
+            else:
+                canvas_height, canvas_width = image_size
             scale = min(canvas_height / ref_height, canvas_width / ref_width)
             new_height = int(ref_height * scale)
             new_width = int(ref_width * scale)
-            white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
-            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+            if fill_max and (canvas_height - new_height) < 16:
+                new_height = canvas_height
+            if fill_max and (canvas_width - new_width) < 16:
+                new_width = canvas_width
             top = (canvas_height - new_height) // 2
             left = (canvas_width - new_width) // 2
-            white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+            if outpainting_dims != None:
+                white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+                white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+            else:
+                white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+                white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
             ref_img = white_canvas
         return ref_img.to(device)

-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
         image_sizes = []
         trim_video = len(keep_frames)
-        canvas_height, canvas_width = image_size
+        def conv_tensor(t, device):
+            return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)

         for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
             prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
             num_frames = total_frames - prepend_count
+            num_frames = min(num_frames, trim_video) if trim_video > 0 else num_frames
             if sub_src_mask is not None and sub_src_video is not None:
-                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
+                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
+                src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
                 # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
                 # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
-                src_video[i] = src_video[i].to(device)
-                src_mask[i] = src_mask[i].to(device)
                 if prepend_count > 0:
                     src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                     src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
@@ -238,8 +253,7 @@ class WanT2V:
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
-                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
-                src_video[i] = src_video[i].to(device)
+                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                 src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
                 if prepend_count > 0:
                     src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
@@ -256,7 +270,7 @@ class WanT2V:
             for k, frame in enumerate(inject_frames):
                 if frame != None:
-                    src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
+                    src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
                     src_mask[i][:, k:k+1] = 0
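In the rewritten prepare_source above, the nested conv_tensor helper replaces the former VaceVideoProcessor loading path: it takes frames stacked as T x H x W x C in [0, 255] and returns C x T x H x W in [-1, 1]. A quick sanity check of that mapping (the shape is illustrative):

import torch

def conv_tensor(t, device):  # copied from the diff above
    return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)

t = torch.tensor([[[[0.0, 127.5, 255.0]]]])       # 1 frame, 1x1 spatial, 3 channels
print(conv_tensor(t, "cpu").flatten().tolist())   # [-1.0, 0.0, 1.0]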

@@ -78,11 +78,37 @@ def remove_background(img, session=None):
     img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
     return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)

-def save_image(tensor_image, name):
-    import numpy as np
-    tensor_image = tensor_image.clone()
-    tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
-    Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
+def convert_tensor_to_image(t, frame_no = -1):
+    t = t[:, frame_no] if frame_no >= 0 else t
+    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
+
+def save_image(tensor_image, name, frame_no = -1):
+    convert_tensor_to_image(tensor_image, frame_no).save(name)
+
+def get_outpainting_full_area_dimensions(frame_height, frame_width, outpainting_dims):
+    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right = outpainting_dims
+    frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
+    frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
+    return frame_height, frame_width
+
+def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
+    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right = outpainting_dims
+    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
+    height = int(raw_height / block_size) * block_size
+    extra_height = raw_height - height
+    raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
+    width = int(raw_width / block_size) * block_size
+    extra_width = raw_width - width
+    margin_top = int(outpainting_top / (100 + outpainting_top + outpainting_bottom) * final_height)
+    if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
+        margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
+    if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
+    margin_left = int(outpainting_left / (100 + outpainting_left + outpainting_right) * final_width)
+    if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
+        margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
+    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
+    return height, width, margin_top, margin_left

 def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
     if fit_into_canvas == None:
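get_outpainting_frame_location is the inverse of get_outpainting_full_area_dimensions: given the final canvas size and the outpainting percentages, it recovers the block-aligned size of the original content and the margins at which to place it. A worked example (the import path assumes the repo layout):

from wan.utils.utils import get_outpainting_frame_location

# Outpaint 25% above and 25% below a 480x832 guide (dims are top, bottom, left, right):
h, w, top, left = get_outpainting_frame_location(480, 832, [25, 25, 0, 0], block_size=8)
print(h, w, top, left)  # 320 832 80 0 -> the guide fills rows 80..400 of the 480-row canvas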
wgp.py

@@ -16,7 +16,7 @@ import json
 import wan
 from wan.utils import notification_sound
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
-from wan.utils.utils import cache_video
+from wan.utils.utils import cache_video, convert_tensor_to_image, save_image
 from wan.modules.attention import get_attention_modes, get_supported_attention_modes
 import torch
 import gc
@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 target_mmgp_version = "3.4.9"
-WanGP_version = "6.2"
+WanGP_version = "6.21"
 settings_version = 2
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1793,12 +1793,12 @@ def get_default_settings(model_type):
         "slg_end_perc": 90
     }
-    if model_type in ("hunyuan","hunyuan_i2v"):
+    if model_type in ["hunyuan","hunyuan_i2v"]:
         ui_defaults.update({
             "guidance_scale": 7.0,
         })
-    if model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
+    if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]:
         ui_defaults.update({
             "guidance_scale": 6.0,
             "flow_shift": 8,
@@ -1811,7 +1811,7 @@ def get_default_settings(model_type):
         })
-    if model_type in ("phantom_1.3B", "phantom_14B"):
+    if model_type in ["phantom_1.3B", "phantom_14B"]:
         ui_defaults.update({
             "guidance_scale": 7.5,
             "flow_shift": 5,
@@ -1820,27 +1820,27 @@ def get_default_settings(model_type):
             # "resolution": "1280x720"
         })
-    elif model_type in ("hunyuan_custom"):
+    elif model_type in ["hunyuan_custom"]:
         ui_defaults.update({
             "guidance_scale": 7.5,
             "flow_shift": 13,
             "resolution": "1280x720",
             "video_prompt_type": "I",
         })
-    elif model_type in ("hunyuan_custom_audio"):
+    elif model_type in ["hunyuan_custom_audio"]:
         ui_defaults.update({
             "guidance_scale": 7.5,
             "flow_shift": 13,
             "video_prompt_type": "I",
         })
-    elif model_type in ("hunyuan_custom_edit"):
+    elif model_type in ["hunyuan_custom_edit"]:
         ui_defaults.update({
             "guidance_scale": 7.5,
             "flow_shift": 13,
             "video_prompt_type": "MVAI",
             "sliding_window_size": 129,
         })
-    elif model_type in ("hunyuan_avatar"):
+    elif model_type in ["hunyuan_avatar"]:
         ui_defaults.update({
             "guidance_scale": 7.5,
             "flow_shift": 5,
@@ -1848,7 +1848,7 @@ def get_default_settings(model_type):
             "video_length": 129,
             "video_prompt_type": "I",
         })
-    elif model_type in ("vace_14B"):
+    elif model_type in ["vace_14B"]:
         ui_defaults.update({
             "sliding_window_discard_last_frames": 0,
         })
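The tuple-to-list changes above are not just stylistic: ("hunyuan_custom") without a trailing comma is plain parentheses around a string, so the in test silently becomes a substring match rather than a membership test. For instance:

print("custom" in ("hunyuan_custom"))  # True  -> substring test on a str
print("custom" in ["hunyuan_custom"])  # False -> actual membership test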
@@ -2063,8 +2063,8 @@ def download_models(model_filename, model_type):
     shared_def = {
         "repoId" : "DeepBeepMeep/Wan2.1",
-        "sourceFolderList" : [ "pose", "depth", "mask", "wav2vec", "" ],
-        "fileList" : [ [],["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
+        "sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", "" ],
+        "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
                        [ "flownet.pkl" ] ]
     }
     process_files_def(**shared_def)
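sourceFolderList and fileList are parallel, positional lists, so the new "scribble" folder pairs with ["netG_A_latest.pth"] (and "pose" now explicitly lists its two ONNX files). A sketch of the pairing:

folders = ["pose", "scribble", "depth", "mask", "wav2vec", ""]
files = [["dw-ll_ucoco_384.onnx", "yolox_l.onnx"], ["netG_A_latest.pth"],
         ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"],
         ["config.json", "feature_extractor_config.json", "model.safetensors",
          "preprocessor_config.json", "special_tokens_map.json",
          "tokenizer_config.json", "vocab.json"],
         ["flownet.pkl"]]
for folder, names in zip(folders, files):
    print(folder or "<repo root>", "->", names)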
@@ -2813,6 +2813,12 @@ def get_preprocessor(process_type, inpaint_color):
         from preprocessing.gray import GrayVideoAnnotator
         cfg_dict = {}
         anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0]
+    elif process_type=="scribble":
+        from preprocessing.scribble import ScribbleVideoAnnotator
+        cfg_dict = {
+            "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"
+        }
+        anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0]
     elif process_type=="inpaint":
         anno_ins = lambda img : inpaint_color
         # anno_ins = lambda img : np.full_like(img, inpaint_color)
@@ -2821,7 +2827,7 @@ def get_preprocessor(process_type, inpaint_color):
     return anno_ins

 def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None):
-    from wan.utils.utils import calculate_new_dimensions
+    from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions

     def mask_to_xyxy_box(mask):
         rows, cols = np.where(mask == 255)
@@ -2837,7 +2843,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
         box = [int(x) for x in box]
         return box

-    if not input_video_path:
+    if not input_video_path or max_frames <= 0:
         return None, None
     any_mask = input_mask_path != None
     pose_special = "pose" in process_type
@@ -2859,24 +2865,17 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
     frame_height, frame_width, _ = video[0].shape

     if outpainting_dims != None:
-        outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
         if fit_canvas != None:
-            frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
-            frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
+            frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
         else:
-            frame_height,frame_width = height, width
+            frame_height, frame_width = height, width

     if fit_canvas != None:
         height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)

     if outpainting_dims != None:
         final_height, final_width = height, width
-        height = int(height / ((100 + outpainting_top + outpainting_bottom) / 100))
-        width = int(width / ((100 + outpainting_left + outpainting_right) / 100))
-        margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
-        if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
-        margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
-        if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
+        height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)

     if any_mask:
         num_frames = min(len(video), len(mask_video))
@@ -3250,6 +3249,8 @@ def generate_video(
     original_image_refs = image_refs
     frames_to_inject = []
+    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
+
     if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
         frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
         frames_positions_list = frames_positions_list[:len(image_refs)]
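video_guide_outpainting travels through the UI as one space-separated string; a "#" prefix (checkbox off) or an all-zero value disables it, anything else parses to [top, bottom, left, right] percentages. A hypothetical helper mirroring the inline expression above:

def parse_outpainting(s):
    if s is None or len(s) == 0 or s == "0 0 0 0" or s.startswith("#"):
        return None
    return [int(v) for v in s.split(" ")]

print(parse_outpainting("25 25 0 0"))   # [25, 25, 0, 0]
print(parse_outpainting("#25 25 0 0"))  # None (disabled, but the value is preserved)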
@@ -3259,8 +3260,10 @@ def generate_video(
         for i, pos in enumerate(frames_positions_list):
             frames_to_inject[pos] = image_refs[i]

     if video_guide == None and video_source == None and not "L" in image_prompt_type:
-        from wan.utils.utils import resize_lanczos, calculate_new_dimensions
+        from wan.utils.utils import resize_lanczos, calculate_new_dimensions, get_outpainting_full_area_dimensions
         w, h = image_refs[0].size
+        if outpainting_dims != None:
+            h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
         default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
         fit_canvas = None
@@ -3465,6 +3468,11 @@ def generate_video(
                 image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
                 video_guide_copy = video_guide
                 video_mask_copy = video_mask
+                keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
+                if len(error) > 0:
+                    raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
+                keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
+
                 if "V" in video_prompt_type:
                     extra_label = ""
                     if "X" in video_prompt_type:
@@ -3472,18 +3480,21 @@ def generate_video(
                     elif "Y" in video_prompt_type:
                         process_outside_mask = "depth"
                         extra_label = " and Depth"
+                    elif "W" in video_prompt_type:
+                        process_outside_mask = "scribble"
+                        extra_label = " and Shapes"
                     else:
                         process_outside_mask = None
                     preprocess_type = None
-                    # if "P" in video_prompt_type and "D" in video_prompt_type :
-                    #     progress_args = [0, get_latest_status(state,"Extracting Open Pose and Depth Information")]
-                    #     preprocess_type = "pose_depth"
                     if "P" in video_prompt_type :
                         progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
                         preprocess_type = "pose"
                     elif "D" in video_prompt_type :
                         progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
                         preprocess_type = "depth"
+                    elif "S" in video_prompt_type :
+                        progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
+                        preprocess_type = "scribble"
                     elif "C" in video_prompt_type :
                         progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
                         preprocess_type = "gray"
@@ -3497,8 +3508,7 @@ def generate_video(
                         progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
                         preprocess_type = "vace"
                     send_cmd("progress", progress_args)
-                    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
-                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= current_video_length if guide_start_frame == 0 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
+                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
                     if video_guide_copy != None:
                         if sample_fit_canvas != None:
                             image_size = video_guide_copy.shape[-3: -1]
@@ -3506,14 +3516,10 @@ def generate_video(
                         refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
                         if video_mask_copy != None:
                             refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
-                    keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
-                    if len(error) > 0:
-                        raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
-                    keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
                     frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length]

                     src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
-                                                                                   [video_mask_copy ],
+                                                                                   [video_mask_copy],
                                                                                    [image_refs_copy],
                                                                                    current_video_length, image_size = image_size, device ="cpu",
                                                                                    original_video= "O" in video_prompt_type,
@@ -3521,8 +3527,11 @@ def generate_video(
                                                                                    start_frame = guide_start_frame,
                                                                                    pre_src_video = [pre_video_guide],
                                                                                    fit_into_canvas = sample_fit_canvas,
                                                                                    inject_frames= frames_to_inject_parsed,
+                                                                                   outpainting_dims = outpainting_dims,
                                                                                    )
+                    if len(frames_to_inject_parsed):
+                        refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:]
                     if sample_fit_canvas != None:
                         image_size = src_video[0].shape[-2:]
                         sample_fit_canvas = None
@@ -4912,35 +4921,30 @@ def del_in_sequence(source_str, letters):
     return ret

-def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
+def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
     video_prompt_type = del_in_sequence(video_prompt_type, "FI")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
     visible = "I" in video_prompt_type
-    return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs)
+    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
+    return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )

 def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
-    video_prompt_type = del_in_sequence(video_prompt_type, "XYNA")
+    video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
     visible= "A" in video_prompt_type
     return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )

 def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
-    video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
+    video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     visible = "V" in video_prompt_type
     mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
     vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
-    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
+    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)

 def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
     video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
-    video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
-    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
-    visible = "V" in video_prompt_type
-    mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
-    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
-    return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= visible ),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
+    return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)

 def refresh_preview(state):
     gen = get_gen_info(state)
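Every control option is packed into the video_prompt_type string as single letters (V = control video, A = mask, P = pose, D = depth, S = scribble/shapes, C = gray, M = inpaint, U = keep unchanged, X/Y/W = outside-mask process variants, N = negate mask, I = reference images, F = injected frames); del_in_sequence and add_to_sequence edit that string. A hypothetical re-implementation to show the intent (the real bodies live in wgp.py):

def del_in_sequence(source_str, letters):
    # drop every occurrence of the given letters from the code string
    return "".join(c for c in source_str if c not in letters)

def add_to_sequence(source_str, letters):
    # append letters that are not already present
    return source_str + "".join(c for c in letters if c not in source_str)

print(del_in_sequence("DVA", "DSPCMUV"))  # "A": clears the previous guide-process choice
print(add_to_sequence("A", "SV"))         # "ASV": selects Shapes transfer + control video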
@@ -5211,13 +5215,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("No Control Video", ""),
                                 ("Transfer Human Motion", "PV"),
                                 ("Transfer Depth", "DV"),
-                                # ("Transfer Human Motion & Depth", "DPV"),
-                                ("Recolorize Control Video", "CV"),
+                                ("Transfer Shapes", "SV"),
+                                ("Recolorize", "CV"),
                                 ("Inpainting", "MV"),
                                 ("Vace raw format", "V"),
                                 ("Keep Unchanged", "UV"),
                             ],
-                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
+                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                             label="Control Video Process", scale = 2, visible= True
                         )
                     elif hunyuan_video_custom_edit:
@@ -5226,7 +5230,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("Inpaint Control Video", "MV"),
                                 ("Transfer Human Motion", "PMV"),
                             ],
-                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
+                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                             label="Video to Video", scale = 3, visible= True
                         )
                     else:
@@ -5254,8 +5258,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("Non Masked Area, rest Inpainted", "XNA"),
                                 ("Masked Area, rest Depth", "YA"),
                                 ("Non Masked Area, rest Depth", "YNA"),
+                                ("Masked Area, rest Shapes", "WA"),
+                                ("Non Masked Area, rest Shapes", "WNA"),
                             ],
-                            value= filter_letters(video_prompt_type_value, "XYNA"),
+                            value= filter_letters(video_prompt_type_value, "XYWNA"),
                             visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
                             label="Area Processed", scale = 2
                         )
@@ -5280,11 +5286,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
                     keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last

-                    with gr.Column(visible= "V" in video_prompt_type_value and vace) as video_guide_outpainting_col:
+                    with gr.Column(visible= ("V" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
                         video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                         video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                         with gr.Group():
-                            video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                            video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                             with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                                 video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                                 video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5298,6 +5304,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
                     if (phantom or hunyuan_video_custom) and not "I" in video_prompt_type_value: video_prompt_type_value += "I"
+                    if hunyuan_t2v and not "I" in video_prompt_type_value: video_prompt_type_value = del_in_sequence(video_prompt_type_value, "I")

                     image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
                                             type ="pil", show_label= True,
@@ -5544,7 +5551,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
                 gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
                 if diffusion_forcing:
-                    sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
+                    sliding_window_size = gr.Slider(37, 257, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                     sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                     sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
@@ -5559,7 +5566,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
                 else:
-                    sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
+                    sliding_window_size = gr.Slider(5, 257, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                     sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                     sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True)
                     sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -5661,7 +5668,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
         image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
         video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
-        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions ])
+        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
         video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
         video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand])
         multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt])