Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 06:15:17 +00:00
Added outpainting on injected frames and a shapes (scribble) preprocessor
parent 8b146a8d7b
commit d0a32c67a0
@@ -26,7 +26,8 @@ class DepthV2Annotator:
        self.model.load_state_dict(
            torch.load(
                pretrained_model,
                map_location=self.device
                map_location=self.device,
                weights_only=True
            )
        )
        self.model.eval()
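The hunk above switches the depth checkpoint load to torch.load(..., weights_only=True), which restricts unpickling to tensors and basic containers instead of arbitrary Python objects. A minimal sketch of the same pattern (the checkpoint path below is illustrative, not taken from this commit):

import torch

# With weights_only=True, torch.load refuses to reconstruct arbitrary pickled
# objects, so a tampered checkpoint file cannot run code at load time.
state_dict = torch.load("ckpts/depth/depth_anything_v2_vitl.pth",
                        map_location="cpu", weights_only=True)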
							
								
								
									
preprocessing/scribble.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image


norm_layer = nn.InstanceNorm2d

def convert_to_torch(image):
    if isinstance(image, Image.Image):
        image = torch.from_numpy(np.array(image)).float()
    elif isinstance(image, torch.Tensor):
        image = image.clone()
    elif isinstance(image, np.ndarray):
        image = torch.from_numpy(image.copy()).float()
    else:
        raise TypeError(f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.')
    return image


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features)
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class ContourInference(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(ContourInference, self).__init__()

        # Initial convolution block
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True)
        ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features,
                                   out_features,
                                   3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out


class ScribbleAnnotator:
    def __init__(self, cfg, device=None):
        input_nc = cfg.get('INPUT_NC', 3)
        output_nc = cfg.get('OUTPUT_NC', 1)
        n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
        sigmoid = cfg.get('SIGMOID', True)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = ContourInference(input_nc, output_nc, n_residual_blocks,
                                      sigmoid)
        self.model.load_state_dict(torch.load(pretrained_model, weights_only=True))
        self.model = self.model.eval().requires_grad_(False).to(self.device)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        is_batch = False if len(image.shape) == 3 else True
        image = convert_to_torch(image)
        if len(image.shape) == 3:
            image = rearrange(image, 'h w c -> 1 c h w')
        image = image.float().div(255).to(self.device)
        contour_map = self.model(image)
        contour_map = (contour_map.squeeze(dim=1) * 255.0).clip(
            0, 255).cpu().numpy().astype(np.uint8)
        contour_map = contour_map[..., None].repeat(3, -1)
        if not is_batch:
            contour_map = contour_map.squeeze()
        return contour_map


class ScribbleVideoAnnotator(ScribbleAnnotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            anno_frame = super().forward(np.array(frame))
            ret_frames.append(anno_frame)
        return ret_frames
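For orientation, a minimal usage sketch of the new annotator as it is wired up later in wgp.py (the checkpoint path comes from that hunk; the all-black input frame is only a placeholder):

import numpy as np
from preprocessing.scribble import ScribbleVideoAnnotator

cfg_dict = {"PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"}
annotator = ScribbleVideoAnnotator(cfg_dict)
# forward() takes a list of frames and returns one H x W x 3 uint8 scribble map per frame.
frames = [np.zeros((480, 832, 3), dtype=np.uint8)]  # placeholder frame
scribble_maps = annotator.forward(frames)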
@@ -27,6 +27,7 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
from wan.utils.basic_flowmatch import FlowMatchScheduler
from wan.utils.utils import get_outpainting_frame_location

def optimized_scale(positive_flat, negative_flat):
@@ -188,38 +189,52 @@ class WanT2V:
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
        from wan.utils.utils import save_image
        ref_width, ref_height = ref_img.size
        if (ref_height, ref_width) == image_size:
        if (ref_height, ref_width) == image_size and outpainting_dims  == None:
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        else:
            canvas_height, canvas_width = image_size
            if outpainting_dims != None:
                final_height, final_width = image_size
                canvas_height, canvas_width, margin_top, margin_left =   get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 8)
            else:
                canvas_height, canvas_width = image_size
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
            white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            if fill_max  and (canvas_height - new_height) < 16:
                new_height = canvas_height
            if fill_max  and (canvas_width - new_width) < 16:
                new_width = canvas_width
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
            white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            if outpainting_dims != None:
                white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
                white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
            else:
                white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
                white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
            ref_img = white_canvas
        return ref_img.to(device)

    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, original_video = False, keep_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = []):
    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, original_video = False, keep_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
        image_sizes = []
        trim_video = len(keep_frames)
        canvas_height, canvas_width = image_size
        def conv_tensor(t, device):
            return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)

        for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
            prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
            num_frames = total_frames - prepend_count
            num_frames = min(num_frames, trim_video) if trim_video > 0 else num_frames
            if sub_src_mask is not None and sub_src_video is not None:
                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
                # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127  in [0, 255])
                # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127  in [0, 255]) and 1 = Inpainting (in fact 255  in [0, 255])
                src_video[i] = src_video[i].to(device)
                src_mask[i] = src_mask[i].to(device)
                if prepend_count > 0:
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                    src_mask[i] =  torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
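To make the new placement arithmetic concrete, a small standalone sketch (all numbers are illustrative): the reference image is scaled and centered inside the inner canvas returned by get_outpainting_frame_location, then pasted into the full-size outpainted canvas at an extra offset of (margin_top, margin_left):

import torch

# Illustrative values: a 864x1536 final frame whose inner (non-outpainted)
# area is 720x1280 and starts 144 px from the top and 128 px from the left.
final_height, final_width = 864, 1536
canvas_height, canvas_width, margin_top, margin_left = 720, 1280, 144, 128
new_height, new_width = 720, 992            # reference image resized to fit the inner canvas
top = (canvas_height - new_height) // 2     # centering inside the inner canvas
left = (canvas_width - new_width) // 2
ref_img = torch.zeros(3, 1, new_height, new_width)  # placeholder image tensor in [-1, 1]
white_canvas = torch.full((3, 1, final_height, final_width), 1.0)
white_canvas[:, :, margin_top + top:margin_top + top + new_height,
             margin_left + left:margin_left + left + new_width] = ref_img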
@@ -238,8 +253,7 @@ class WanT2V:
                    src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(image_size)
            else:
                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
                src_video[i] = src_video[i].to(device)
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
                if prepend_count > 0:
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
@@ -256,7 +270,7 @@ class WanT2V:

            for k, frame in enumerate(inject_frames):
                if frame != None:
                    src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
                    src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
                    src_mask[i][:, k:k+1] = 0
@@ -78,11 +78,37 @@ def remove_background(img, session=None):
    img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)

def save_image(tensor_image, name):
    import numpy as np
    tensor_image = tensor_image.clone()
    tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
    Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
def convert_tensor_to_image(t, frame_no = -1):
    t = t[:, frame_no] if frame_no >= 0 else t
    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())

def save_image(tensor_image, name, frame_no = -1):
    convert_tensor_to_image(tensor_image, frame_no).save(name)

def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims):
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
    frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
    frame_width =  int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
    return frame_height, frame_width

def get_outpainting_frame_location(final_height, final_width,  outpainting_dims, block_size = 8):
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
    height = int(raw_height / block_size) * block_size
    extra_height = raw_height - height

    raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
    width = int(raw_width / block_size) * block_size
    extra_width = raw_width - width
    margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
    if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
        margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
    if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
    margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
    if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
        margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
    return height, width, margin_top, margin_left

def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
    if fit_into_canvas == None:
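A quick worked example of the two outpainting helpers added above (percentages chosen for illustration): growing a 720x1280 frame by 20% at the top and 10% on each side, then asking where the original frame sits inside the grown canvas:

from wan.utils.utils import get_outpainting_full_area_dimensions, get_outpainting_frame_location

outpainting_dims = [20, 0, 10, 10]  # top, bottom, left, right, in percent of the source size
full_h, full_w = get_outpainting_full_area_dimensions(720, 1280, outpainting_dims)   # -> (864, 1536)
height, width, margin_top, margin_left = get_outpainting_frame_location(full_h, full_w, outpainting_dims, 8)
# -> (720, 1280, 144, 128): the source occupies a block-aligned 720x1280 window,
#    offset 144 px from the top and 128 px from the left of the 864x1536 canvas.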
							
								
								
									
wgp.py (123 changed lines)
@@ -16,7 +16,7 @@ import json
import wan
from wan.utils import notification_sound
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import cache_video
from wan.utils.utils import cache_video, convert_tensor_to_image, save_image
from wan.modules.attention import get_attention_modes, get_supported_attention_modes
import torch
import gc
@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.4.9"
WanGP_version = "6.2"
WanGP_version = "6.21"
settings_version = 2
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1793,12 +1793,12 @@ def get_default_settings(model_type):
                "slg_end_perc": 90
            }

            if model_type in ("hunyuan","hunyuan_i2v"):
            if model_type in ["hunyuan","hunyuan_i2v"]:
                ui_defaults.update({
                    "guidance_scale": 7.0,
                })

            if model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
            if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]:
                ui_defaults.update({
                    "guidance_scale": 6.0,
                    "flow_shift": 8,
@@ -1811,7 +1811,7 @@ def get_default_settings(model_type):
                })


            if model_type in ("phantom_1.3B", "phantom_14B"):
            if model_type in ["phantom_1.3B", "phantom_14B"]:
                ui_defaults.update({
                    "guidance_scale": 7.5,
                    "flow_shift": 5,
@@ -1820,27 +1820,27 @@ def get_default_settings(model_type):
                    # "resolution": "1280x720"
                })

            elif model_type in ("hunyuan_custom"):
            elif model_type in ["hunyuan_custom"]:
                ui_defaults.update({
                    "guidance_scale": 7.5,
                    "flow_shift": 13,
                    "resolution": "1280x720",
                    "video_prompt_type": "I",
                })
            elif model_type in ("hunyuan_custom_audio"):
            elif model_type in ["hunyuan_custom_audio"]:
                ui_defaults.update({
                    "guidance_scale": 7.5,
                    "flow_shift": 13,
                    "video_prompt_type": "I",
                })
            elif model_type in ("hunyuan_custom_edit"):
            elif model_type in ["hunyuan_custom_edit"]:
                ui_defaults.update({
                    "guidance_scale": 7.5,
                    "flow_shift": 13,
                    "video_prompt_type": "MVAI",
                    "sliding_window_size": 129,
                })
            elif model_type in ("hunyuan_avatar"):
            elif model_type in ["hunyuan_avatar"]:
                ui_defaults.update({
                    "guidance_scale": 7.5,
                    "flow_shift": 5,
@@ -1848,7 +1848,7 @@ def get_default_settings(model_type):
                    "video_length": 129,
                    "video_prompt_type": "I",
                })
            elif model_type in ("vace_14B"):
            elif model_type in ["vace_14B"]:
                ui_defaults.update({
                    "sliding_window_discard_last_frames": 0,
                })
@@ -2063,8 +2063,8 @@ def download_models(model_filename, model_type):

    shared_def = {
        "repoId" : "DeepBeepMeep/Wan2.1",
        "sourceFolderList" : [ "pose", "depth", "mask", "wav2vec", ""  ],
        "fileList" : [ [],["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
        "sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", ""  ],
        "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
                [ "flownet.pkl"  ] ]
    }
    process_files_def(**shared_def)
@@ -2813,6 +2813,12 @@ def get_preprocessor(process_type, inpaint_color):
        from preprocessing.gray import GrayVideoAnnotator
        cfg_dict = {}
        anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0]
    elif process_type=="scribble":
        from preprocessing.scribble import ScribbleVideoAnnotator
        cfg_dict = {
                "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"
            }
        anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0]
    elif process_type=="inpaint":
        anno_ins = lambda img : inpaint_color
        # anno_ins = lambda img : np.full_like(img, inpaint_color)
@@ -2821,7 +2827,7 @@ def get_preprocessor(process_type, inpaint_color):
    return anno_ins

def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,  max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None):
    from wan.utils.utils import calculate_new_dimensions
    from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions

    def mask_to_xyxy_box(mask):
        rows, cols = np.where(mask == 255)
@@ -2837,7 +2843,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
        box = [int(x) for x in box]
        return box

    if not input_video_path:
    if not input_video_path or max_frames <= 0:
        return None, None
    any_mask = input_mask_path != None
    pose_special = "pose" in process_type
@@ -2859,24 +2865,17 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
    frame_height, frame_width, _ = video[0].shape

    if outpainting_dims != None:
        outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
        if fit_canvas != None:
            frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
            frame_width =  int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
            frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
        else:
            frame_height,frame_width = height, width
            frame_height, frame_width = height, width

    if fit_canvas != None:
        height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)

    if outpainting_dims != None:
        final_height, final_width = height, width
        height = int(height / ((100 + outpainting_top + outpainting_bottom) / 100))
        width =  int(width / ((100 + outpainting_left + outpainting_right) / 100))
        margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
        if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
        margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
        if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
        height, width, margin_top, margin_left =  get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 8)

    if any_mask:
        num_frames = min(len(video), len(mask_video))
@@ -3250,6 +3249,8 @@ def generate_video(

    original_image_refs = image_refs
    frames_to_inject = []
    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]

    if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
        frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
        frames_positions_list = frames_positions_list[:len(image_refs)]
@@ -3259,8 +3260,10 @@ def generate_video(
            for i, pos in enumerate(frames_positions_list):
                frames_to_inject[pos] = image_refs[i]
            if video_guide == None and video_source == None and not "L" in image_prompt_type:
                from wan.utils.utils import resize_lanczos, calculate_new_dimensions
                from wan.utils.utils import resize_lanczos, calculate_new_dimensions, get_outpainting_full_area_dimensions
                w, h = image_refs[0].size
                if outpainting_dims != None:
                    h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
                default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
                fit_canvas = None
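The outpainting spec travels through the UI as a plain string: four space-separated percentages (top, bottom, left, right), with a leading "#" marking the feature as disabled while keeping the values around. A small illustration of the parsing rule used above (the helper name exists only for this example):

def parse_outpainting(video_guide_outpainting):
    # Mirrors the inline expression: disabled when empty, "0 0 0 0", or "#"-prefixed.
    if video_guide_outpainting == None or len(video_guide_outpainting) == 0 \
            or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#"):
        return None
    return [int(v) for v in video_guide_outpainting.split(" ")]

parse_outpainting("20 0 10 10")   # -> [20, 0, 10, 10]
parse_outpainting("#20 0 10 10")  # -> None (outpainting turned off)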
@@ -3465,6 +3468,11 @@ def generate_video(
                image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
                video_guide_copy = video_guide
                video_mask_copy = video_mask
                keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
                if len(error) > 0:
                    raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
                keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]

                if "V" in video_prompt_type:
                    extra_label = ""
                    if "X" in video_prompt_type:
@@ -3472,18 +3480,21 @@ def generate_video(
                    elif "Y" in video_prompt_type:
                        process_outside_mask = "depth"
                        extra_label = " and Depth"
                    elif "W" in video_prompt_type:
                        process_outside_mask = "scribble"
                        extra_label = " and Shapes"
                    else:
                        process_outside_mask = None
                    preprocess_type = None
                    # if "P" in video_prompt_type and "D" in video_prompt_type :
                    #     progress_args = [0, get_latest_status(state,"Extracting Open Pose and Depth Information")]
                    #     preprocess_type = "pose_depth"
                    if "P" in video_prompt_type :
                        progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
                        preprocess_type = "pose"
                    elif "D" in video_prompt_type :
                        progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
                        preprocess_type = "depth"
                    elif "S" in video_prompt_type :
                        progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
                        preprocess_type = "scribble"
                    elif "C" in video_prompt_type :
                        progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
                        preprocess_type = "gray"
@@ -3497,8 +3508,7 @@ def generate_video(
                        progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
                        preprocess_type = "vace"
                    send_cmd("progress", progress_args)
                    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= current_video_length if guide_start_frame == 0 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
                    if video_guide_copy != None:
                        if sample_fit_canvas != None:
                            image_size = video_guide_copy.shape[-3: -1]
@@ -3506,14 +3516,10 @@ def generate_video(
                        refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
                        if video_mask_copy != None:
                            refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
                keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
                if len(error) > 0:
                    raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
                keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
                frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length]

                src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
                                                                        [video_mask_copy ],
                                                                        [video_mask_copy],
                                                                        [image_refs_copy],
                                                                        current_video_length, image_size = image_size, device ="cpu",
                                                                        original_video= "O" in video_prompt_type,
@@ -3521,8 +3527,11 @@ def generate_video(
                                                                        start_frame = guide_start_frame,
                                                                        pre_src_video = [pre_video_guide],
                                                                        fit_into_canvas = sample_fit_canvas,
                                                                        inject_frames= frames_to_inject_parsed,
                                                                        inject_frames= frames_to_inject_parsed,
                                                                        outpainting_dims = outpainting_dims,
                                                                        )
                if len(frames_to_inject_parsed):
                    refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]  + image_refs[nb_frames_positions:]
                if sample_fit_canvas != None:
                    image_size = src_video[0].shape[-2:]
                    sample_fit_canvas = None
@@ -4912,35 +4921,30 @@ def del_in_sequence(source_str, letters):
    return ret


def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
    video_prompt_type = del_in_sequence(video_prompt_type, "FI")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
    visible = "I" in video_prompt_type
    return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs)
    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
    return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )

def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
    video_prompt_type = del_in_sequence(video_prompt_type, "XYNA")
    video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
    visible= "A" in video_prompt_type
    return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )

def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
    video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
    video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
    visible = "V" in video_prompt_type
    mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)
    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible),gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)

def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
    video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
    video_prompt_type = del_in_sequence(video_prompt_type, "DPCMUV")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
    visible = "V" in video_prompt_type
    mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
    vace = get_base_model_type(state["model_type"]) in ("vace_1.3B","vace_14B")
    return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= visible ),gr.update(visible= visible and vace), gr.update(visible= visible and not "U" in video_prompt_type), gr.update(visible= mask_visible), gr.update(visible= mask_visible)

    return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)

def refresh_preview(state):
    gen = get_gen_info(state)
@@ -5211,13 +5215,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                ("No Control Video", ""),
                                ("Transfer Human Motion", "PV"),
                                ("Transfer Depth", "DV"),
                                # ("Transfer Human Motion & Depth", "DPV"),
                                ("Recolorize Control Video", "CV"),
                                ("Transfer Shapes", "SV"),
                                ("Recolorize", "CV"),
                                ("Inpainting", "MV"),
                                ("Vace raw format", "V"),
                                ("Keep Unchanged", "UV"),
                            ],
                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                            label="Control Video Process", scale = 2, visible= True
                        )
                    elif hunyuan_video_custom_edit:
@@ -5226,7 +5230,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                ("Inpaint Control Video", "MV"),
                                ("Transfer Human Motion", "PMV"),
                            ],
                            value=filter_letters(video_prompt_type_value, "DPCMUV"),
                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
                            label="Video to Video", scale = 3, visible= True
                        )
                    else:
@@ -5254,8 +5258,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                ("Non Masked Area, rest Inpainted", "XNA"),
                                ("Masked Area, rest Depth", "YA"),
                                ("Non Masked Area, rest Depth", "YNA"),
                                ("Masked Area, rest Shapes", "WA"),
                                ("Non Masked Area, rest Shapes", "WNA"),
                            ],
                            value= filter_letters(video_prompt_type_value, "XYNA"),
                            value= filter_letters(video_prompt_type_value, "XYWNA"),
                            visible=  "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
                            label="Area Processed", scale = 2
                        )
@@ -5280,11 +5286,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non

                video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
                keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
                with gr.Column(visible= "V" in video_prompt_type_value and vace) as video_guide_outpainting_col:
                with gr.Column(visible= ("V" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
                    video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                    video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                    with gr.Group():
                        video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                        video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                        with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                            video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                            video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5298,6 +5304,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )

                if (phantom or hunyuan_video_custom) and not "I" in video_prompt_type_value: video_prompt_type_value += "I"
                if hunyuan_t2v and not "I" in video_prompt_type_value: video_prompt_type_value = del_in_sequence(video_prompt_type_value, "I")

                image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
                        type ="pil",   show_label= True,
@@ -5544,7 +5551,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
                        gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
                        if diffusion_forcing:
                            sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                            sliding_window_size = gr.Slider(37, 257, value=ui_defaults.get("sliding_window_size", 97), step=20, label="  (recommended to keep it at 97)")
                            sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                            sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
                            sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
@@ -5559,7 +5566,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                            sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
                            sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
                        else:
                            sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                            sliding_window_size = gr.Slider(5, 257, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                            sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                            sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True)
                            sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -5661,7 +5668,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non

            image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
            video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
            video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions ])
            video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
            video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
            video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand])
            multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt])