diff --git a/preprocessing/depth_anything_v2/depth.py b/preprocessing/depth_anything_v2/depth.py
index 1b3685b..fd4fcdd 100644
--- a/preprocessing/depth_anything_v2/depth.py
+++ b/preprocessing/depth_anything_v2/depth.py
@@ -26,7 +26,8 @@ class DepthV2Annotator:
         self.model.load_state_dict(
             torch.load(
                 pretrained_model,
-                map_location=self.device
+                map_location=self.device,
+                weights_only=True
             )
         )
         self.model.eval()
diff --git a/preprocessing/scribble.py b/preprocessing/scribble.py
new file mode 100644
index 0000000..d408288
--- /dev/null
+++ b/preprocessing/scribble.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from PIL import Image
+
+
+norm_layer = nn.InstanceNorm2d
+
+def convert_to_torch(image):
+    if isinstance(image, Image.Image):
+        image = torch.from_numpy(np.array(image)).float()
+    elif isinstance(image, torch.Tensor):
+        image = image.clone()
+    elif isinstance(image, np.ndarray):
+        image = torch.from_numpy(image.copy()).float()
+    else:
+        raise TypeError(f'Unsupported datatype {type(image)}, only np.ndarray, torch.Tensor and PIL Image are supported.')
+    return image
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_features):
+        super(ResidualBlock, self).__init__()
+
+        conv_block = [
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features),
+            nn.ReLU(inplace=True),
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features)
+        ]
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        return x + self.conv_block(x)
+
+
+class ContourInference(nn.Module):
+    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+        super(ContourInference, self).__init__()
+
+        # Initial convolution block
+        model0 = [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(input_nc, 64, 7),
+            norm_layer(64),
+            nn.ReLU(inplace=True)
+        ]
+        self.model0 = nn.Sequential(*model0)
+
+        # Downsampling
+        model1 = []
+        in_features = 64
+        out_features = in_features * 2
+        for _ in range(2):
+            model1 += [
+                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                norm_layer(out_features),
+                nn.ReLU(inplace=True)
+            ]
+            in_features = out_features
+            out_features = in_features * 2
+        self.model1 = nn.Sequential(*model1)
+
+        model2 = []
+        # Residual blocks
+        for _ in range(n_residual_blocks):
+            model2 += [ResidualBlock(in_features)]
+        self.model2 = nn.Sequential(*model2)
+
+        # Upsampling
+        model3 = []
+        out_features = in_features // 2
+        for _ in range(2):
+            model3 += [
+                nn.ConvTranspose2d(in_features,
+                                   out_features,
+                                   3,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1),
+                norm_layer(out_features),
+                nn.ReLU(inplace=True)
+            ]
+            in_features = out_features
+            out_features = in_features // 2
+        self.model3 = nn.Sequential(*model3)
+
+        # Output layer
+        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
+        if sigmoid:
+            model4 += [nn.Sigmoid()]
+
+        self.model4 = nn.Sequential(*model4)
+
+    def forward(self, x, cond=None):
+        out = self.model0(x)
+        out = self.model1(out)
+        out = self.model2(out)
+        out = self.model3(out)
+        out = self.model4(out)
+
+        return out
+
+
+class ScribbleAnnotator:
+    def __init__(self, cfg, device=None):
+        input_nc = cfg.get('INPUT_NC', 3)
+        output_nc = cfg.get('OUTPUT_NC', 1)
+        n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
+        sigmoid = cfg.get('SIGMOID', True)
+        pretrained_model = cfg['PRETRAINED_MODEL']
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
"cpu") if device is None else device + self.model = ContourInference(input_nc, output_nc, n_residual_blocks, + sigmoid) + self.model.load_state_dict(torch.load(pretrained_model, weights_only=True)) + self.model = self.model.eval().requires_grad_(False).to(self.device) + + @torch.no_grad() + @torch.inference_mode() + @torch.autocast('cuda', enabled=False) + def forward(self, image): + is_batch = False if len(image.shape) == 3 else True + image = convert_to_torch(image) + if len(image.shape) == 3: + image = rearrange(image, 'h w c -> 1 c h w') + image = image.float().div(255).to(self.device) + contour_map = self.model(image) + contour_map = (contour_map.squeeze(dim=1) * 255.0).clip( + 0, 255).cpu().numpy().astype(np.uint8) + contour_map = contour_map[..., None].repeat(3, -1) + if not is_batch: + contour_map = contour_map.squeeze() + return contour_map + + +class ScribbleVideoAnnotator(ScribbleAnnotator): + def forward(self, frames): + ret_frames = [] + for frame in frames: + anno_frame = super().forward(np.array(frame)) + ret_frames.append(anno_frame) + return ret_frames \ No newline at end of file diff --git a/wan/text2video.py b/wan/text2video.py index 210d068..3e55ac9 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -27,6 +27,7 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed from .utils.vace_preprocessor import VaceVideoProcessor from wan.utils.basic_flowmatch import FlowMatchScheduler +from wan.utils.utils import get_outpainting_frame_location def optimized_scale(positive_flat, negative_flat): @@ -188,38 +189,52 @@ class WanT2V: def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device): + def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None): + from wan.utils.utils import save_image ref_width, ref_height = ref_img.size - if (ref_height, ref_width) == image_size: + if (ref_height, ref_width) == image_size and outpainting_dims == None: ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) else: - canvas_height, canvas_width = image_size + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) + else: + canvas_height, canvas_width = image_size scale = min(canvas_height / ref_height, canvas_width / ref_width) new_height = int(ref_height * scale) new_width = int(ref_width * scale) - white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] - ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if fill_max and (canvas_height - new_height) < 16: + new_height = canvas_height + if fill_max and (canvas_width - new_width) < 16: + new_width = canvas_width top = (canvas_height - new_height) // 2 left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] + 
diff --git a/wan/text2video.py b/wan/text2video.py
index 210d068..3e55ac9 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -27,6 +27,7 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from wan.modules.posemb_layers import get_rotary_pos_embed
 from .utils.vace_preprocessor import VaceVideoProcessor
 from wan.utils.basic_flowmatch import FlowMatchScheduler
+from wan.utils.utils import get_outpainting_frame_location
 
 
 def optimized_scale(positive_flat, negative_flat):
@@ -188,38 +189,52 @@ class WanT2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 
-    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
+    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
+        from wan.utils.utils import save_image
         ref_width, ref_height = ref_img.size
-        if (ref_height, ref_width) == image_size:
+        if (ref_height, ref_width) == image_size and outpainting_dims == None:
             ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
         else:
-            canvas_height, canvas_width = image_size
+            if outpainting_dims != None:
+                final_height, final_width = image_size
+                canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
+            else:
+                canvas_height, canvas_width = image_size
             scale = min(canvas_height / ref_height, canvas_width / ref_width)
             new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
-            white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
-            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+            if fill_max and (canvas_height - new_height) < 16:
+                new_height = canvas_height
+            if fill_max and (canvas_width - new_width) < 16:
+                new_width = canvas_width
             top = (canvas_height - new_height) // 2
             left = (canvas_width - new_width) // 2
-            white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+            if outpainting_dims != None:
+                white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+                white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+            else:
+                white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+                white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
             ref_img = white_canvas
         return ref_img.to(device)
 
-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
         image_sizes = []
         trim_video = len(keep_frames)
-        canvas_height, canvas_width = image_size
+        def conv_tensor(t, device):
+            return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
 
         for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
             prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
-            num_frames = total_frames - prepend_count
+            num_frames = total_frames - prepend_count
+            num_frames = min(num_frames, trim_video) if trim_video > 0 else num_frames
             if sub_src_mask is not None and sub_src_video is not None:
-                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
+                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
+                src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
                 # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
                 # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
-                src_video[i] = src_video[i].to(device)
-                src_mask[i] = src_mask[i].to(device)
                 if prepend_count > 0:
                     src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                     src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
@@ -238,8 +253,7 @@ class WanT2V:
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
-                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
-                src_video[i] = src_video[i].to(device)
+                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                 src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
                 if prepend_count > 0:
                     src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
@@ -256,7 +270,7 @@ class WanT2V:
 
         for k, frame in enumerate(inject_frames):
             if frame != None:
-                src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
+                src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
                 src_mask[i][:, k:k+1] = 0
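
Editor's note: `prepare_source` no longer routes clips through `VaceVideoProcessor`; it now expects already-decoded frame tensors and only normalizes and reorders them via the nested `conv_tensor` helper. A minimal standalone sketch of that conversion, using a dummy tensor in place of real decoded frames (an assumption for illustration):

import torch

def conv_tensor(t, device):
    # (frames, height, width, channels) uint8 in [0, 255] -> (channels, frames, height, width) float in [-1, 1]
    return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)

frames = torch.randint(0, 256, (17, 480, 832, 3), dtype=torch.uint8)  # dummy 17-frame clip (assumption)
out = conv_tensor(frames, "cpu")
print(out.shape, out.min().item() >= -1, out.max().item() <= 1)  # torch.Size([3, 17, 480, 832]) True True
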
diff --git a/wan/utils/utils.py b/wan/utils/utils.py
index 052b6b4..b2ee774 100644
--- a/wan/utils/utils.py
+++ b/wan/utils/utils.py
@@ -78,11 +78,37 @@ def remove_background(img, session=None):
     img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
     return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
 
-def save_image(tensor_image, name):
-    import numpy as np
-    tensor_image = tensor_image.clone()
-    tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
-    Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
+def convert_tensor_to_image(t, frame_no = -1):
+    t = t[:, frame_no] if frame_no >= 0 else t
+    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
+
+def save_image(tensor_image, name, frame_no = -1):
+    convert_tensor_to_image(tensor_image, frame_no).save(name)
+
+def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims):
+    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
+    frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
+    frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
+    return frame_height, frame_width
+
+def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
+    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
+    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
+    height = int(raw_height / block_size) * block_size
+    extra_height = raw_height - height
+
+    raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
+    width = int(raw_width / block_size) * block_size
+    extra_width = raw_width - width
+    margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
+    if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
+        margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
+    if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
+    margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
+    if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
+        margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
+    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
+    return height, width, margin_top, margin_left
 
 def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
     if fit_into_canvas == None:
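
Editor's note: a small worked example of the geometry helper above (not part of the patch; the resolution and percentages are illustrative). Given a 720x1280 final frame with 20% outpainting added at the top and bottom, the original content is reduced to a block-aligned 512x1280 region placed 103 px from the top:

from wan.utils.utils import get_outpainting_frame_location

# outpainting_dims is [top, bottom, left, right] in percent, as parsed in wgp.py
h, w, margin_top, margin_left = get_outpainting_frame_location(720, 1280, [20, 20, 0, 0], 8)
print(h, w, margin_top, margin_left)  # 512 1280 103 0
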
("hunyuan","hunyuan_i2v"): + if model_type in ["hunyuan","hunyuan_i2v"]: ui_defaults.update({ "guidance_scale": 7.0, }) - if model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]: ui_defaults.update({ "guidance_scale": 6.0, "flow_shift": 8, @@ -1811,7 +1811,7 @@ def get_default_settings(model_type): }) - if model_type in ("phantom_1.3B", "phantom_14B"): + if model_type in ["phantom_1.3B", "phantom_14B"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, @@ -1820,27 +1820,27 @@ def get_default_settings(model_type): # "resolution": "1280x720" }) - elif model_type in ("hunyuan_custom"): + elif model_type in ["hunyuan_custom"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, "resolution": "1280x720", "video_prompt_type": "I", }) - elif model_type in ("hunyuan_custom_audio"): + elif model_type in ["hunyuan_custom_audio"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "I", }) - elif model_type in ("hunyuan_custom_edit"): + elif model_type in ["hunyuan_custom_edit"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "MVAI", "sliding_window_size": 129, }) - elif model_type in ("hunyuan_avatar"): + elif model_type in ["hunyuan_avatar"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, @@ -1848,7 +1848,7 @@ def get_default_settings(model_type): "video_length": 129, "video_prompt_type": "I", }) - elif model_type in ("vace_14B"): + elif model_type in ["vace_14B"]: ui_defaults.update({ "sliding_window_discard_last_frames": 0, }) @@ -2063,8 +2063,8 @@ def download_models(model_filename, model_type): shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "depth", "mask", "wav2vec", "" ], - "fileList" : [ [],["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], + "sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", "" ], + "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2813,6 +2813,12 @@ def get_preprocessor(process_type, inpaint_color): from preprocessing.gray import GrayVideoAnnotator cfg_dict = {} anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0] + elif process_type=="scribble": + from preprocessing.scribble import ScribbleVideoAnnotator + cfg_dict = { + "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth" + } + anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0] elif process_type=="inpaint": anno_ins = lambda img : inpaint_color # anno_ins = lambda img : np.full_like(img, inpaint_color) @@ -2821,7 +2827,7 @@ def get_preprocessor(process_type, inpaint_color): return anno_ins def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None): - from wan.utils.utils import 
+    from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
 
     def mask_to_xyxy_box(mask):
         rows, cols = np.where(mask == 255)
@@ -2837,7 +2843,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
         box = [int(x) for x in box]
         return box
 
-    if not input_video_path:
+    if not input_video_path or max_frames <= 0:
         return None, None
     any_mask = input_mask_path != None
     pose_special = "pose" in process_type
@@ -2859,24 +2865,17 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
     frame_height, frame_width, _ = video[0].shape
 
     if outpainting_dims != None:
-        outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
         if fit_canvas != None:
-            frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
-            frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
+            frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
         else:
-            frame_height,frame_width = height, width
+            frame_height, frame_width = height, width
 
     if fit_canvas != None:
         height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
 
     if outpainting_dims != None:
         final_height, final_width = height, width
-        height = int(height / ((100 + outpainting_top + outpainting_bottom) / 100))
-        width = int(width / ((100 + outpainting_left + outpainting_right) / 100))
-        margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
-        if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
-        margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
-        if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
+        height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
 
     if any_mask:
         num_frames = min(len(video), len(mask_video))
@@ -3250,6 +3249,8 @@ def generate_video(
     original_image_refs = image_refs
     frames_to_inject = []
+    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
+
     if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
         frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
         frames_positions_list = frames_positions_list[:len(image_refs)]
@@ -3259,8 +3260,10 @@ def generate_video(
             for i, pos in enumerate(frames_positions_list):
                 frames_to_inject[pos] = image_refs[i]
         if video_guide == None and video_source == None and not "L" in image_prompt_type:
-            from wan.utils.utils import resize_lanczos, calculate_new_dimensions
+            from wan.utils.utils import resize_lanczos, calculate_new_dimensions, get_outpainting_full_area_dimensions
             w, h = image_refs[0].size
+            if outpainting_dims != None:
+                h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
             default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
             fit_canvas = None
@@ -3465,6 +3468,11 @@ def generate_video(
                    image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
                    video_guide_copy = video_guide
                    video_mask_copy = video_mask
+                    keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
+                    if len(error) > 0:
+                        raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
+                    keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
+
                    if "V" in video_prompt_type:
                        extra_label = ""
                        if "X" in video_prompt_type:
@@ -3472,18 +3480,21 @@ def generate_video(
                        elif "Y" in video_prompt_type:
                            process_outside_mask = "depth"
                            extra_label = " and Depth"
+                        elif "W" in video_prompt_type:
+                            process_outside_mask = "scribble"
+                            extra_label = " and Shapes"
                        else:
                            process_outside_mask = None
                        preprocess_type = None
-                        # if "P" in video_prompt_type and "D" in video_prompt_type :
-                        #     progress_args = [0, get_latest_status(state,"Extracting Open Pose and Depth Information")]
-                        #     preprocess_type = "pose_depth"
                        if "P" in video_prompt_type :
                            progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
                            preprocess_type = "pose"
                        elif "D" in video_prompt_type :
                            progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
                            preprocess_type = "depth"
+                        elif "S" in video_prompt_type :
+                            progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
+                            preprocess_type = "scribble"
                        elif "C" in video_prompt_type :
                            progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
                            preprocess_type = "gray"
@@ -3497,8 +3508,7 @@ def generate_video(
                            progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
                            preprocess_type = "vace"
                        send_cmd("progress", progress_args)
-                        outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
-                        video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= current_video_length if guide_start_frame == 0 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
+                        video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
                        if video_guide_copy != None:
                            if sample_fit_canvas != None:
                                image_size = video_guide_copy.shape[-3: -1]
@@ -3506,14 +3516,10 @@ def generate_video(
                            refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
                            if video_mask_copy != None:
                                refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
-                    keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
-                    if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length] frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length] src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy], - [video_mask_copy ], + [video_mask_copy], [image_refs_copy], current_video_length, image_size = image_size, device ="cpu", original_video= "O" in video_prompt_type, @@ -3521,8 +3527,11 @@ def generate_video( start_frame = guide_start_frame, pre_src_video = [pre_video_guide], fit_into_canvas = sample_fit_canvas, - inject_frames= frames_to_inject_parsed, + inject_frames= frames_to_inject_parsed, + outpainting_dims = outpainting_dims, ) + if len(frames_to_inject_parsed): + refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:] if sample_fit_canvas != None: image_size = src_video[0].shape[-2:] sample_fit_canvas = None @@ -4912,35 +4921,30 @@ def del_in_sequence(source_str, letters): return ret -def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs): +def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs): video_prompt_type = del_in_sequence(video_prompt_type, "FI") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs) visible = "I" in video_prompt_type - return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs) + vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B") + return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask): - video_prompt_type = del_in_sequence(video_prompt_type, "XYNA") + video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) visible= "A" in video_prompt_type return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible ) def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide): - video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV") + video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B") - return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible) + return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible) def 
 def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
     video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
-    video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
-    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
-    visible = "V" in video_prompt_type
-    mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
-    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
-    return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= visible ),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
-
+    return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)
 
 def refresh_preview(state):
     gen = get_gen_info(state)
@@ -5211,13 +5215,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                ("No Control Video", ""),
                                ("Transfer Human Motion", "PV"),
                                ("Transfer Depth", "DV"),
-                                # ("Transfer Human Motion & Depth", "DPV"),
-                                ("Recolorize Control Video", "CV"),
+                                ("Transfer Shapes", "SV"),
+                                ("Recolorize", "CV"),
                                ("Inpainting", "MV"),
                                ("Vace raw format", "V"),
                                ("Keep Unchanged", "UV"),
                            ],
-                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
+                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                            label="Control Video Process", scale = 2, visible= True
                        )
                    elif hunyuan_video_custom_edit:
@@ -5226,7 +5230,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                ("Inpaint Control Video", "MV"),
                                ("Transfer Human Motion", "PMV"),
                            ],
-                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
+                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                            label="Video to Video", scale = 3, visible= True
                        )
                    else:
@@ -5254,8 +5258,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                            ("Non Masked Area, rest Inpainted", "XNA"),
                            ("Masked Area, rest Depth", "YA"),
                            ("Non Masked Area, rest Depth", "YNA"),
+                            ("Masked Area, rest Shapes", "WA"),
+                            ("Non Masked Area, rest Shapes", "WNA"),
                        ],
-                        value= filter_letters(video_prompt_type_value, "XYNA"),
+                        value= filter_letters(video_prompt_type_value, "XYWNA"),
                        visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
                        label="Area Processed", scale = 2
                    )
@@ -5280,11 +5286,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
                    keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
 
-                    with gr.Column(visible= "V" in video_prompt_type_value and vace) as video_guide_outpainting_col:
+                    with gr.Column(visible= ("V" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
                        video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                        video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                        with gr.Group():
-                            video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                            video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                            with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                                video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                                video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5298,6 +5304,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
 
                    if (phantom or hunyuan_video_custom) and not "I" in video_prompt_type_value: video_prompt_type_value += "I"
+                    if hunyuan_t2v and not "I" in video_prompt_type_value: video_prompt_type_value = del_in_sequence(video_prompt_type_value, "I")
                    image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
                            type ="pil", show_label= True,
@@ -5544,7 +5551,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    gr.Markdown("A Sliding Window allows you to generate video with a duration not limited by the Model")
                    gr.Markdown("It is automatically turned on if the number of frames to generate is higher than the Window Size")
                    if diffusion_forcing:
-                        sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
+                        sliding_window_size = gr.Slider(37, 257, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                        sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                        sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
                        sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
@@ -5559,7 +5566,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
                        sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
                    else:
-                        sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
+                        sliding_window_size = gr.Slider(5, 257, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                        sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                        sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True)
                        sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -5661,7 +5668,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
        video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
-        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions ])
+        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
        video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
        video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand])
        multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt])
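
Editor's note: in `generate_video` the UI's single outpainting text field is now turned into `outpainting_dims` up front. A minimal standalone sketch of that parsing rule follows (the helper name `parse_outpainting` is mine, not the patch's; the expression mirrors the one added in the `@@ -3250,6 +3249,8 @@` hunk, where a leading "#" marks the feature as disabled):

def parse_outpainting(video_guide_outpainting):
    # "#"-prefixed, empty, or "0 0 0 0" means no outpainting;
    # otherwise the field is "top bottom left right" percentages.
    if video_guide_outpainting is None or len(video_guide_outpainting) == 0 \
            or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#"):
        return None
    return [int(v) for v in video_guide_outpainting.split(" ")]

print(parse_outpainting("#20 20 0 0"))  # None (outpainting checkbox off)
print(parse_outpainting("20 20 0 0"))   # [20, 20, 0, 0]
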