added outpainting on injected frames and a shapes (scribble) preprocessor

DeepBeepMeep 2025-06-20 23:57:07 +02:00
parent 8b146a8d7b
commit d0a32c67a0
5 changed files with 276 additions and 80 deletions


@ -26,7 +26,8 @@ class DepthV2Annotator:
self.model.load_state_dict(
torch.load(
pretrained_model,
map_location=self.device
map_location=self.device,
weights_only=True
)
)
self.model.eval()

preprocessing/scribble.py (new file, 148 lines)

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
norm_layer = nn.InstanceNorm2d
def convert_to_torch(image):
if isinstance(image, Image.Image):
image = torch.from_numpy(np.array(image)).float()
elif isinstance(image, torch.Tensor):
image = image.clone()
elif isinstance(image, np.ndarray):
image = torch.from_numpy(image.copy()).float()
else:
raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL.Image are supported.')
return image
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
conv_block = [
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
norm_layer(in_features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
norm_layer(in_features)
]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
class ContourInference(nn.Module):
def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
super(ContourInference, self).__init__()
# Initial convolution block
model0 = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, 64, 7),
norm_layer(64),
nn.ReLU(inplace=True)
]
self.model0 = nn.Sequential(*model0)
# Downsampling
model1 = []
in_features = 64
out_features = in_features * 2
for _ in range(2):
model1 += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
norm_layer(out_features),
nn.ReLU(inplace=True)
]
in_features = out_features
out_features = in_features * 2
self.model1 = nn.Sequential(*model1)
model2 = []
# Residual blocks
for _ in range(n_residual_blocks):
model2 += [ResidualBlock(in_features)]
self.model2 = nn.Sequential(*model2)
# Upsampling
model3 = []
out_features = in_features // 2
for _ in range(2):
model3 += [
nn.ConvTranspose2d(in_features,
out_features,
3,
stride=2,
padding=1,
output_padding=1),
norm_layer(out_features),
nn.ReLU(inplace=True)
]
in_features = out_features
out_features = in_features // 2
self.model3 = nn.Sequential(*model3)
# Output layer
model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
if sigmoid:
model4 += [nn.Sigmoid()]
self.model4 = nn.Sequential(*model4)
def forward(self, x, cond=None):
out = self.model0(x)
out = self.model1(out)
out = self.model2(out)
out = self.model3(out)
out = self.model4(out)
return out
class ScribbleAnnotator:
def __init__(self, cfg, device=None):
input_nc = cfg.get('INPUT_NC', 3)
output_nc = cfg.get('OUTPUT_NC', 1)
n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
sigmoid = cfg.get('SIGMOID', True)
pretrained_model = cfg['PRETRAINED_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = ContourInference(input_nc, output_nc, n_residual_blocks,
sigmoid)
self.model.load_state_dict(torch.load(pretrained_model, weights_only=True))
self.model = self.model.eval().requires_grad_(False).to(self.device)
@torch.no_grad()
@torch.inference_mode()
@torch.autocast('cuda', enabled=False)
def forward(self, image):
is_batch = len(image.shape) != 3
image = convert_to_torch(image)
if len(image.shape) == 3:
image = rearrange(image, 'h w c -> 1 c h w')
image = image.float().div(255).to(self.device)
contour_map = self.model(image)
contour_map = (contour_map.squeeze(dim=1) * 255.0).clip(
0, 255).cpu().numpy().astype(np.uint8)
contour_map = contour_map[..., None].repeat(3, -1)
if not is_batch:
contour_map = contour_map.squeeze()
return contour_map
class ScribbleVideoAnnotator(ScribbleAnnotator):
def forward(self, frames):
ret_frames = []
for frame in frames:
anno_frame = super().forward(np.array(frame))
ret_frames.append(anno_frame)
return ret_frames
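Below is a minimal usage sketch for the new annotator (not part of the commit). It mirrors the wiring added to get_preprocessor in wgp.py later in this diff and assumes the scribble checkpoint has already been downloaded to ckpts/scribble/netG_A_latest.pth; the frame sizes are arbitrary placeholders.

# Usage sketch: run the scribble ("Shapes") preprocessor on dummy RGB frames.
import numpy as np
from preprocessing.scribble import ScribbleVideoAnnotator

cfg = {"PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"}   # same path wgp.py uses
annotator = ScribbleVideoAnnotator(cfg)

frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(2)]  # stand-in video frames
contour_frames = annotator.forward(frames)   # list of H x W x 3 uint8 contour maps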


@ -27,6 +27,7 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
from wan.utils.basic_flowmatch import FlowMatchScheduler
from wan.utils.utils import get_outpainting_frame_location
def optimized_scale(positive_flat, negative_flat):
@ -188,38 +189,52 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
from wan.utils.utils import save_image
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size:
if (ref_height, ref_width) == image_size and outpainting_dims == None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
else:
canvas_height, canvas_width = image_size
if outpainting_dims != None:
final_height, final_width = image_size
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
else:
canvas_height, canvas_width = image_size
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if fill_max and (canvas_height - new_height) < 16:
new_height = canvas_height
if fill_max and (canvas_width - new_width) < 16:
new_width = canvas_width
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = white_canvas
return ref_img.to(device)
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
image_sizes = []
trim_video = len(keep_frames)
canvas_height, canvas_width = image_size
def conv_tensor(t, device):
return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
num_frames = min(num_frames, trim_video) if trim_video > 0 else num_frames
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
# src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
@ -238,8 +253,7 @@ class WanT2V:
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
src_video[i] = src_video[i].to(device)
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
@ -256,7 +270,7 @@ class WanT2V:
for k, frame in enumerate(inject_frames):
if frame != None:
src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
src_mask[i][:, k:k+1] = 0
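The injected-frame placement above can be reproduced in isolation. The sketch below is illustrative only: all sizes are made up and the fill_max handling is omitted. It resizes a reference image into the block-aligned content area returned by get_outpainting_frame_location and pastes it at the computed margins inside the full outpainted canvas, which is what fit_image_into_canvas now does when outpainting_dims is set.

# Standalone sketch of the outpainting placement applied to an injected frame.
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from wan.utils.utils import get_outpainting_frame_location

final_height, final_width = 480, 832          # target outpainted frame size
outpainting_dims = (20, 0, 10, 10)            # top, bottom, left, right (percent)
canvas_h, canvas_w, margin_top, margin_left = get_outpainting_frame_location(
    final_height, final_width, outpainting_dims, 8)

ref_img = Image.new("RGB", (640, 360))        # stand-in reference image
scale = min(canvas_h / ref_img.height, canvas_w / ref_img.width)
new_h, new_w = int(ref_img.height * scale), int(ref_img.width * scale)
top, left = (canvas_h - new_h) // 2, (canvas_w - new_w) // 2

ref = TF.to_tensor(ref_img.resize((new_w, new_h), Image.Resampling.LANCZOS)).sub_(0.5).div_(0.5).unsqueeze(1)
canvas = torch.full((3, 1, final_height, final_width), 0.0)   # 0 = grey fill in [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_h,
       margin_left + left:margin_left + left + new_w] = ref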


@ -78,11 +78,37 @@ def remove_background(img, session=None):
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def save_image(tensor_image, name):
import numpy as np
tensor_image = tensor_image.clone()
tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
def convert_tensor_to_image(t, frame_no = -1):
t = t[:, frame_no] if frame_no >= 0 else t
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims):
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
return frame_height, frame_width
def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
height = int(raw_height / block_size) * block_size
extra_height = raw_height - height
raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
width = int(raw_width / block_size) * block_size
extra_width = raw_width - width
margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
return height, width, margin_top, margin_left
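# --- Worked example (not part of the commit): how the two outpainting helpers
# above fit together; the numbers are purely illustrative.
# get_outpainting_full_area_dimensions grows a source frame by the requested
# percentages, and get_outpainting_frame_location then recovers the block-aligned
# size and position of the original content inside that enlarged frame.
from wan.utils.utils import (get_outpainting_full_area_dimensions,
                             get_outpainting_frame_location)
outpainting_dims = (20, 0, 10, 10)   # top, bottom, left, right (percent)
full_h, full_w = get_outpainting_full_area_dimensions(400, 704, outpainting_dims)
# full_h, full_w == 480, 844  (400 * 1.2 and 704 * 1.2, truncated to int)
height, width, margin_top, margin_left = get_outpainting_frame_location(full_h, full_w, outpainting_dims, 8)
# height, width            -> block-of-8 aligned size of the original content area
# margin_top, margin_left  -> offset of that area inside the 480 x 844 outpainted frame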
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None:

wgp.py (121 changed lines)

@ -16,7 +16,7 @@ import json
import wan
from wan.utils import notification_sound
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import cache_video
from wan.utils.utils import cache_video, convert_tensor_to_image, save_image
from wan.modules.attention import get_attention_modes, get_supported_attention_modes
import torch
import gc
@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.9"
WanGP_version = "6.2"
WanGP_version = "6.21"
settings_version = 2
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -1793,12 +1793,12 @@ def get_default_settings(model_type):
"slg_end_perc": 90
}
if model_type in ("hunyuan","hunyuan_i2v"):
if model_type in ["hunyuan","hunyuan_i2v"]:
ui_defaults.update({
"guidance_scale": 7.0,
})
if model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]:
ui_defaults.update({
"guidance_scale": 6.0,
"flow_shift": 8,
@ -1811,7 +1811,7 @@ def get_default_settings(model_type):
})
if model_type in ("phantom_1.3B", "phantom_14B"):
if model_type in ["phantom_1.3B", "phantom_14B"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 5,
@ -1820,27 +1820,27 @@ def get_default_settings(model_type):
# "resolution": "1280x720"
})
elif model_type in ("hunyuan_custom"):
elif model_type in ["hunyuan_custom"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"resolution": "1280x720",
"video_prompt_type": "I",
})
elif model_type in ("hunyuan_custom_audio"):
elif model_type in ["hunyuan_custom_audio"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "I",
})
elif model_type in ("hunyuan_custom_edit"):
elif model_type in ["hunyuan_custom_edit"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "MVAI",
"sliding_window_size": 129,
})
elif model_type in ("hunyuan_avatar"):
elif model_type in ["hunyuan_avatar"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 5,
@ -1848,7 +1848,7 @@ def get_default_settings(model_type):
"video_length": 129,
"video_prompt_type": "I",
})
elif model_type in ("vace_14B"):
elif model_type in ["vace_14B"]:
ui_defaults.update({
"sliding_window_discard_last_frames": 0,
})
@ -2063,8 +2063,8 @@ def download_models(model_filename, model_type):
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
"sourceFolderList" : [ "pose", "depth", "mask", "wav2vec", "" ],
"fileList" : [ [],["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
"sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
[ "flownet.pkl" ] ]
}
process_files_def(**shared_def)
@ -2813,6 +2813,12 @@ def get_preprocessor(process_type, inpaint_color):
from preprocessing.gray import GrayVideoAnnotator
cfg_dict = {}
anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0]
elif process_type=="scribble":
from preprocessing.scribble import ScribbleVideoAnnotator
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"
}
anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0]
elif process_type=="inpaint":
anno_ins = lambda img : inpaint_color
# anno_ins = lambda img : np.full_like(img, inpaint_color)
@ -2821,7 +2827,7 @@ def get_preprocessor(process_type, inpaint_color):
return anno_ins
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None):
from wan.utils.utils import calculate_new_dimensions
from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
def mask_to_xyxy_box(mask):
rows, cols = np.where(mask == 255)
@ -2837,7 +2843,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
box = [int(x) for x in box]
return box
if not input_video_path:
if not input_video_path or max_frames <= 0:
return None, None
any_mask = input_mask_path != None
pose_special = "pose" in process_type
@ -2859,24 +2865,17 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
frame_height, frame_width, _ = video[0].shape
if outpainting_dims != None:
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
if fit_canvas != None:
frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
else:
frame_height,frame_width = height, width
frame_height, frame_width = height, width
if fit_canvas != None:
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
if outpainting_dims != None:
final_height, final_width = height, width
height = int(height / ((100 + outpainting_top + outpainting_bottom) / 100))
width = int(width / ((100 + outpainting_left + outpainting_right) / 100))
margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
if any_mask:
num_frames = min(len(video), len(mask_video))
@ -3250,6 +3249,8 @@ def generate_video(
original_image_refs = image_refs
frames_to_inject = []
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
frames_positions_list = frames_positions_list[:len(image_refs)]
@ -3259,8 +3260,10 @@ def generate_video(
for i, pos in enumerate(frames_positions_list):
frames_to_inject[pos] = image_refs[i]
if video_guide == None and video_source == None and not "L" in image_prompt_type:
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
from wan.utils.utils import resize_lanczos, calculate_new_dimensions, get_outpainting_full_area_dimensions
w, h = image_refs[0].size
if outpainting_dims != None:
h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
fit_canvas = None
@ -3465,6 +3468,11 @@ def generate_video(
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source does in-place modifications
video_guide_copy = video_guide
video_mask_copy = video_mask
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
if "V" in video_prompt_type:
extra_label = ""
if "X" in video_prompt_type:
@ -3472,18 +3480,21 @@ def generate_video(
elif "Y" in video_prompt_type:
process_outside_mask = "depth"
extra_label = " and Depth"
elif "W" in video_prompt_type:
process_outside_mask = "scribble"
extra_label = " and Shapes"
else:
process_outside_mask = None
preprocess_type = None
# if "P" in video_prompt_type and "D" in video_prompt_type :
# progress_args = [0, get_latest_status(state,"Extracting Open Pose and Depth Information")]
# preprocess_type = "pose_depth"
if "P" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
preprocess_type = "depth"
elif "S" in video_prompt_type :
progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
preprocess_type = "scribble"
elif "C" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
preprocess_type = "gray"
@ -3497,8 +3508,7 @@ def generate_video(
progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
preprocess_type = "vace"
send_cmd("progress", progress_args)
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= current_video_length if guide_start_frame == 0 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
if video_guide_copy != None:
if sample_fit_canvas != None:
image_size = video_guide_copy.shape[-3: -1]
@ -3506,14 +3516,10 @@ def generate_video(
refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
if video_mask_copy != None:
refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length]
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy ],
[video_mask_copy],
[image_refs_copy],
current_video_length, image_size = image_size, device ="cpu",
original_video= "O" in video_prompt_type,
@ -3522,7 +3528,10 @@ def generate_video(
pre_src_video = [pre_video_guide],
fit_into_canvas = sample_fit_canvas,
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
)
if len(frames_to_inject_parsed):
refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:]
if sample_fit_canvas != None:
image_size = src_video[0].shape[-2:]
sample_fit_canvas = None
@ -4912,35 +4921,30 @@ def del_in_sequence(source_str, letters):
return ret
def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
video_prompt_type = del_in_sequence(video_prompt_type, "FI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
visible = "I" in video_prompt_type
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs)
vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
video_prompt_type = del_in_sequence(video_prompt_type, "XYNA")
video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
visible= "A" in video_prompt_type
return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= visible ),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)
def refresh_preview(state):
gen = get_gen_info(state)
@ -5211,13 +5215,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("No Control Video", ""),
("Transfer Human Motion", "PV"),
("Transfer Depth", "DV"),
# ("Transfer Human Motion & Depth", "DPV"),
("Recolorize Control Video", "CV"),
("Transfer Shapes", "SV"),
("Recolorize", "CV"),
("Inpainting", "MV"),
("Vace raw format", "V"),
("Keep Unchanged", "UV"),
],
value=filter_letters(video_prompt_type_value, "DPCMUV"),
value=filter_letters(video_prompt_type_value, "DSPCMUV"),
label="Control Video Process", scale = 2, visible= True
)
elif hunyuan_video_custom_edit:
@ -5226,7 +5230,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Inpaint Control Video", "MV"),
("Transfer Human Motion", "PMV"),
],
value=filter_letters(video_prompt_type_value, "DPCMUV"),
value=filter_letters(video_prompt_type_value, "DSPCMUV"),
label="Video to Video", scale = 3, visible= True
)
else:
@ -5254,8 +5258,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Non Masked Area, rest Inpainted", "XNA"),
("Masked Area, rest Depth", "YA"),
("Non Masked Area, rest Depth", "YNA"),
("Masked Area, rest Shapes", "WA"),
("Non Masked Area, rest Shapes", "WNA"),
],
value= filter_letters(video_prompt_type_value, "XYNA"),
value= filter_letters(video_prompt_type_value, "XYWNA"),
visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
label="Area Processed", scale = 2
)
@ -5280,11 +5286,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
with gr.Column(visible= "V" in video_prompt_type_value and vace) as video_guide_outpainting_col:
with gr.Column(visible= ("V" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@ -5298,6 +5304,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
if (phantom or hunyuan_video_custom) and not "I" in video_prompt_type_value: video_prompt_type_value += "I"
if hunyuan_t2v and not "I" in video_prompt_type_value: video_prompt_type_value = del_in_sequence(video_prompt_type_value, "I")
image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
type ="pil", show_label= True,
@ -5544,7 +5551,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
if diffusion_forcing:
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
sliding_window_size = gr.Slider(37, 257, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
@ -5559,7 +5566,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
else:
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
sliding_window_size = gr.Slider(5, 257, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@ -5661,7 +5668,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions ])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt])