Multi Vace controlnets and multithreaded preprocessing

This commit is contained in:
DeepBeepMeep 2025-06-23 19:48:19 +02:00
parent a790dd8432
commit 28fe3528e8
18 changed files with 1951 additions and 142 deletions

View File

@ -20,6 +20,16 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldn't squeeze any more out of Vace?
- Multithreaded preprocessing when possible for faster generations
- Multithreaded Lanczos frame upsampling as a bonus
- A new Vace preprocessor: *Flow*, to extract fluid motion
- Multi Vace ControlNets: you can now transfer several properties at the same time. This opens new possibilities to explore; for instance, if you transfer *Human Movement* and *Shapes* at the same time, for some reason the lighting of your character will take its environment into account much more (see the short sketch at the end of this update note)
- Injected Frames Outpainting, in case you missed it in WanGP 6.21
Don't know how to use all of the Vace features? Check the Vace Guide embedded in WanGP, as it has also been updated.
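For the technically curious, here is a minimal toy sketch (illustrative only, not WanGP's actual API) of the weighting convention behind Multi Vace ControlNets: each ControlNet contributes its own hint, each hint has its own strength taken from a list, a strength of 0 skips that hint, and the scaled hints are added to the hidden states (see the `WanAttentionBlock` changes later in this commit).

```python
import torch

# Toy illustration of the list-of-scales convention used by the new multi Vace ControlNets.
x = torch.zeros(2, 4)                               # stand-in for the transformer hidden states
hints = [torch.ones(2, 4), 2 * torch.ones(2, 4)]    # one hint per active ControlNet
context_scale = [1.0, 0.5]                          # one strength per ControlNet; 0 would skip it

for hint, scale in zip(hints, context_scale):
    if scale != 0:
        x.add_(hint, alpha=scale)                   # x += scale * hint

print(x)   # every element equals 1 * 1.0 + 2 * 0.5 = 2.0
```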
### June 19 2025: WanGP v6.2, Vace even more Powercharged
👋 Have I told you that I am a big fan of Vace? Here are more goodies to unleash its power:
- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is that *Outpainting* can be combined with all the other Vace modifications; for instance, you can change the main character of your favorite movie at the same time.

View File

@ -20,9 +20,33 @@ def convert_to_numpy(image):
class DepthV2Annotator:
def __init__(self, cfg, device=None):
from .dpt import DepthAnythingV2
# Model configurations for different variants
self.model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
# Get model variant from config, default to 'vitl' if not specified
model_variant = cfg.get('MODEL_VARIANT', 'vitl')
if model_variant not in self.model_configs:
raise ValueError(f"Invalid model variant '{model_variant}'. Must be one of: {list(self.model_configs.keys())}")
pretrained_model = cfg['PRETRAINED_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(self.device)
# Get configuration for the selected model variant
config = self.model_configs[model_variant]
# Initialize model with the appropriate configuration
self.model = DepthAnythingV2(
encoder=config['encoder'],
features=config['features'],
out_channels=config['out_channels']
).to(self.device)
self.model.load_state_dict(
torch.load(
pretrained_model,

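A hypothetical usage sketch for the now-configurable depth annotator above; the `PRETRAINED_MODEL` and `MODEL_VARIANT` keys come from the snippet, but the checkpoint path is illustrative:

```python
# Hypothetical usage sketch: pick the Depth Anything V2 variant from the config dict.
cfg = {
    'PRETRAINED_MODEL': 'ckpts/depth_anything_v2_vitl.pth',  # illustrative path
    'MODEL_VARIANT': 'vitl',                                 # one of: vits, vitb, vitl, vitg
}
annotator = DepthV2Annotator(cfg, device='cuda')
```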
View File

@ -2,13 +2,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import torch
import numpy as np
from . import util
from .wholebody import Wholebody, HWC3, resize_image
from PIL import Image
import onnxruntime as ort
from concurrent.futures import ThreadPoolExecutor
import threading
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -23,8 +25,6 @@ def convert_to_numpy(image):
raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
return image
def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
bodies = pose['bodies']
faces = pose['faces']
@ -43,6 +43,56 @@ def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
return canvas
class OptimizedWholebody:
"""Optimized version of Wholebody for faster serial processing"""
def __init__(self, onnx_det, onnx_pose, device='cuda:0'):
providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
self.device = device
# Explicitly (re-)apply the execution providers to both sessions
self.session_det.set_providers(providers)
self.session_pose.set_providers(providers)
# Get input names once to avoid repeated lookups
self.det_input_name = self.session_det.get_inputs()[0].name
self.pose_input_name = self.session_pose.get_inputs()[0].name
self.pose_output_names = [out.name for out in self.session_pose.get_outputs()]
def __call__(self, ori_img):
from .onnxdet import inference_detector
from .onnxpose import inference_pose
det_result = inference_detector(self.session_det, ori_img)
keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
keypoints_info = np.concatenate(
(keypoints, scores[..., None]), axis=-1)
# compute neck joint
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
# neck score when visualizing pred
neck[:, 2:4] = np.logical_and(
keypoints_info[:, 5, 2:4] > 0.3,
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
new_keypoints_info = np.insert(
keypoints_info, 17, neck, axis=1)
mmpose_idx = [
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
]
openpose_idx = [
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
]
new_keypoints_info[:, openpose_idx] = \
new_keypoints_info[:, mmpose_idx]
keypoints_info = new_keypoints_info
keypoints, scores = keypoints_info[
..., :2], keypoints_info[..., 2]
return keypoints, scores, det_result
class PoseAnnotator:
def __init__(self, cfg, device=None):
onnx_det = cfg['DETECTION_MODEL']
@ -84,9 +134,7 @@ class PoseAnnotator:
candidate[un_visible] = -1
foot = candidate[:, 18:24]
faces = candidate[:, 24:92]
hands = candidate[:, 92:113]
hands = np.vstack([hands, candidate[:, 113:]])
@ -127,10 +175,24 @@ class PoseAnnotator:
return ret_data, det_result
class OptimizedPoseAnnotator(PoseAnnotator):
"""Optimized version using improved Wholebody class"""
def __init__(self, cfg, device=None):
onnx_det = cfg['DETECTION_MODEL']
onnx_pose = cfg['POSE_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.pose_estimation = OptimizedWholebody(onnx_det, onnx_pose, device=self.device)
self.resize_size = cfg.get("RESIZE_SIZE", 1024)
self.use_body = cfg.get('USE_BODY', True)
self.use_face = cfg.get('USE_FACE', True)
self.use_hand = cfg.get('USE_HAND', True)
class PoseBodyFaceAnnotator(PoseAnnotator):
def __init__(self, cfg):
super().__init__(cfg)
self.use_body, self.use_face, self.use_hand = True, True, False
@torch.no_grad()
@torch.inference_mode
def forward(self, image):
@ -138,14 +200,108 @@ class PoseBodyFaceAnnotator(PoseAnnotator):
return ret_data['detected_map_bodyface']
class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
class OptimizedPoseBodyFaceVideoAnnotator:
"""Optimized video annotator with multiple optimization strategies"""
def __init__(self, cfg, num_workers=5, chunk_size=8):
self.cfg = cfg
self.num_workers = num_workers
self.chunk_size = chunk_size
self.use_body, self.use_face, self.use_hand = True, True, False
# Initialize one annotator per worker to avoid ONNX session conflicts
self.annotators = []
for _ in range(num_workers):
annotator = OptimizedPoseAnnotator(cfg)
annotator.use_body, annotator.use_face, annotator.use_hand = True, True, False
self.annotators.append(annotator)
self._current_worker = 0
self._worker_lock = threading.Lock()
def _get_annotator(self):
"""Get next available annotator in round-robin fashion"""
with self._worker_lock:
annotator = self.annotators[self._current_worker]
self._current_worker = (self._current_worker + 1) % len(self.annotators)
return annotator
def _process_single_frame(self, frame_data):
"""Process a single frame with error handling"""
frame, frame_idx = frame_data
try:
annotator = self._get_annotator()
# Convert frame
frame = convert_to_numpy(frame)
input_image = HWC3(frame[..., ::-1])
resized_image = resize_image(input_image, annotator.resize_size)
# Process
ret_data, _ = annotator.process(resized_image, frame.shape[:2])
if 'detected_map_bodyface' in ret_data:
return frame_idx, ret_data['detected_map_bodyface']
else:
# Create empty frame if no detection
h, w = frame.shape[:2]
return frame_idx, np.zeros((h, w, 3), dtype=np.uint8)
except Exception as e:
print(f"Error processing frame {frame_idx}: {e}")
# Return empty frame on error
h, w = frame.shape[:2] if hasattr(frame, 'shape') else (480, 640)
return frame_idx, np.zeros((h, w, 3), dtype=np.uint8)
def forward(self, frames):
"""Process video frames with optimizations"""
if len(frames) == 0:
return []
# For small number of frames, use serial processing to avoid threading overhead
if len(frames) <= 4:
annotator = self.annotators[0]
ret_frames = []
for frame in frames:
anno_frame = super().forward(np.array(frame))
ret_frames.append(anno_frame)
frame = convert_to_numpy(frame)
input_image = HWC3(frame[..., ::-1])
resized_image = resize_image(input_image, annotator.resize_size)
ret_data, _ = annotator.process(resized_image, frame.shape[:2])
if 'detected_map_bodyface' in ret_data:
ret_frames.append(ret_data['detected_map_bodyface'])
else:
h, w = frame.shape[:2]
ret_frames.append(np.zeros((h, w, 3), dtype=np.uint8))
return ret_frames
# For larger videos, use parallel processing
frame_data = [(frame, idx) for idx, frame in enumerate(frames)]
results = [None] * len(frames)
# Process in chunks to manage memory
for chunk_start in range(0, len(frame_data), self.chunk_size * self.num_workers):
chunk_end = min(chunk_start + self.chunk_size * self.num_workers, len(frame_data))
chunk_data = frame_data[chunk_start:chunk_end]
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
chunk_results = list(executor.map(self._process_single_frame, chunk_data))
# Store results in correct order
for frame_idx, result in chunk_results:
results[frame_idx] = result
return results
# Alias for backward compatibility
class PoseBodyFaceVideoAnnotator(OptimizedPoseBodyFaceVideoAnnotator):
"""Backward compatible class name"""
def __init__(self, cfg, num_workers=2, chunk_size=8):
# Use optimized version with conservative settings
super().__init__(cfg, num_workers=num_workers, chunk_size=chunk_size)
# Keep the existing utility functions
import imageio
def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
@ -161,11 +317,7 @@ def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
def get_frames(video_path):
frames = []
# Opens the Video file with CV2
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
print("video fps: " + str(fps))
i = 0
@ -175,9 +327,6 @@ def get_frames(video_path):
break
frames.append(frame)
i += 1
cap.release()
cv2.destroyAllWindows()
return frames, fps
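A hypothetical end-to-end sketch for the multithreaded pose annotator above; the ONNX file names are illustrative, only the cfg keys (`DETECTION_MODEL`, `POSE_MODEL`, `RESIZE_SIZE`) come from the code:

```python
# Hypothetical usage sketch: annotate a whole clip with the thread-pooled pose extractor.
cfg = {
    'DETECTION_MODEL': 'ckpts/pose/yolox_l.onnx',        # illustrative paths
    'POSE_MODEL': 'ckpts/pose/dw-ll_ucoco_384.onnx',
    'RESIZE_SIZE': 1024,
}
annotator = PoseBodyFaceVideoAnnotator(cfg, num_workers=2, chunk_size=8)
frames, fps = get_frames("input.mp4")          # frames as read by OpenCV
pose_maps = annotator.forward(frames)          # one HxWx3 uint8 map per input frame
save_one_video("pose.mp4", pose_maps, fps=int(fps))
```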

58
preprocessing/flow.py Normal file
View File

@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import numpy as np
import argparse
from PIL import Image
def convert_to_numpy(image):
if isinstance(image, Image.Image):
image = np.array(image)
elif isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
elif isinstance(image, np.ndarray):
image = image.copy()
else:
raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
return image
class FlowAnnotator:
def __init__(self, cfg, device=None):
from .raft.raft import RAFT
from .raft.utils.utils import InputPadder
from .raft.utils import flow_viz
params = {
"small": False,
"mixed_precision": False,
"alternate_corr": False
}
params = argparse.Namespace(**params)
pretrained_model = cfg['PRETRAINED_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = RAFT(params)
self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()})
self.model = self.model.to(self.device).eval()
self.InputPadder = InputPadder
self.flow_viz = flow_viz
def forward(self, frames):
# frames / RGB
frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames]
flow_up_list, flow_up_vis_list = [], []
with torch.no_grad():
for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])):
padder = self.InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True)
flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy()
flow_up_vis = self.flow_viz.flow_to_image(flow_up)
flow_up_list.append(flow_up)
flow_up_vis_list.append(flow_up_vis)
return flow_up_list, flow_up_vis_list # RGB
class FlowVisAnnotator(FlowAnnotator):
def forward(self, frames):
flow_up_list, flow_up_vis_list = super().forward(frames)
return flow_up_vis_list[:1] + flow_up_vis_list
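A hypothetical usage sketch for the new *Flow* preprocessor; the RAFT checkpoint path is illustrative, while the `PRETRAINED_MODEL` key and the return convention come from the class above:

```python
# Hypothetical usage sketch: turn a short frame sequence into flow visualisations.
import numpy as np
from PIL import Image

cfg = {'PRETRAINED_MODEL': 'ckpts/flow/raft-things.pth'}   # illustrative path
annotator = FlowVisAnnotator(cfg)
frames = [np.array(Image.open(p).convert('RGB')) for p in ('f0.png', 'f1.png', 'f2.png')]
flow_vis = annotator.forward(frames)   # RGB flow images, padded to one per input frame
```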

View File

View File

@ -0,0 +1,91 @@
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1)
dy = torch.linspace(-r, r, 2*r+1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())
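A small self-check (toy tensors, illustrative only) of what `CorrBlock.corr` computes: the dot product between every pair of feature vectors from the two maps, scaled by `1/sqrt(dim)`:

```python
import torch

# Toy check: CorrBlock.corr is an all-pairs feature dot product scaled by 1/sqrt(dim).
fmap1 = torch.randn(1, 4, 6, 8)   # (batch, dim, ht, wd)
fmap2 = torch.randn(1, 4, 6, 8)
corr = CorrBlock.corr(fmap1, fmap2)                 # shape (1, 6, 8, 1, 6, 8)
ref = torch.einsum('bcij,bckl->bijkl', fmap1, fmap2) / (4 ** 0.5)
assert torch.allclose(corr, ref.unsqueeze(3), atol=1e-5)
```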

View File

@ -0,0 +1,235 @@
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
from .utils import frame_utils
from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class FlowDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class MpiSintel(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
if split == 'test':
self.is_test = True
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
images = sorted(glob(osp.join(idir, '*.png')) )
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
class KITTI(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
return train_loader

View File

@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
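A toy shape check (illustrative only) for the encoders above, which downsample by 8 via the stride-2 stem and the two stride-2 residual stages:

```python
import torch

# Toy shape check: both encoders output features at 1/8 of the input resolution.
x = torch.randn(1, 3, 64, 96)

enc = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
print(enc(x).shape)                     # torch.Size([1, 256, 8, 12])

small = SmallEncoder(output_dim=128, norm_fn='none', dropout=0.0)
print(small(x).shape)                   # torch.Size([1, 128, 8, 12])
```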

144
preprocessing/raft/raft.py Normal file
View File

@ -0,0 +1,144 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.amp.autocast
except:
# dummy autocast for older PyTorch versions without torch.amp.autocast
class autocast:
def __init__(self, device_type=None, enabled=True, **kwargs):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8).to(img.device)
coords1 = coords_grid(N, H//8, W//8).to(img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast('cuda', enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast('cuda', enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast('cuda', enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
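A hypothetical inference sketch for `RAFT` wired up the way `FlowAnnotator` does it earlier in this commit (random weights here, so the output is meaningless; this only exercises the shapes and the `test_mode` path). Input sides must be multiples of 8:

```python
import argparse
import torch

# Toy inference run: same args FlowAnnotator builds, random weights, test_mode outputs.
args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = RAFT(args).eval()
img1 = torch.randint(0, 256, (1, 3, 368, 496)).float()
img2 = torch.randint(0, 256, (1, 3, 368, 496)).float()
with torch.no_grad():
    flow_low, flow_up = model(img1, img2, iters=20, test_mode=True)
print(flow_low.shape, flow_up.shape)    # (1, 2, 46, 62) (1, 2, 368, 496)
```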

View File

@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow
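A toy shape check (illustrative only) for `BasicUpdateBlock` with the large-model settings RAFT uses (`corr_levels=4`, `corr_radius=4`, so the correlation input has 4 * (2*4 + 1)^2 = 324 channels):

```python
import argparse
import torch

# Toy shape check: the update block returns the new hidden state, the 64*9-channel
# upsampling mask and a 2-channel flow update, all at 1/8 resolution.
args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)
net = torch.randn(1, 128, 8, 12)        # GRU hidden state
inp = torch.randn(1, 128, 8, 12)        # context features
corr = torch.randn(1, 324, 8, 12)       # sampled correlation volume
flow = torch.randn(1, 2, 8, 12)         # current flow estimate
net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)  # (1, 576, 8, 12) (1, 2, 8, 12)
```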

View File

View File

@ -0,0 +1,246 @@
import numpy as np
import random
import math
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))  # default 'xy' indexing so the grids match the (ht, wd) flow layout
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid
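A toy run (illustrative only) of `FlowAugmentor` on random data; the crop size just has to fit inside the inputs after the minimum-scale clamp:

```python
import numpy as np

# Toy run: jointly augment an image pair and its dense flow, then crop to 64x64.
aug = FlowAugmentor(crop_size=(64, 64), min_scale=-0.2, max_scale=0.5, do_flip=True)
img1 = np.random.randint(0, 255, (128, 160, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (128, 160, 3), dtype=np.uint8)
flow = np.random.randn(128, 160, 2).astype(np.float32)
img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)           # (64, 64, 3) (64, 64, 2)
```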

View File

@ -0,0 +1,132 @@
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
import numpy as np
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
Expects a flow image of shape [H,W,2].
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
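A tiny usage sketch (synthetic flow, illustrative only) for `flow_to_image`:

```python
import numpy as np

# Toy visualisation: a synthetic flow field rendered with the color wheel.
h, w = 64, 64
u, v = np.meshgrid(np.linspace(-1, 1, w), np.linspace(-1, 1, h))
flow = np.stack([u, v], axis=-1).astype(np.float32)   # [H, W, 2]
rgb = flow_to_image(flow)                             # uint8 [H, W, 3]
print(rgb.shape, rgb.dtype)                           # (64, 64, 3) uint8
```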

View File

@ -0,0 +1,137 @@
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
""" Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def writeFlow(filename,uv,v=None):
""" Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert(uv.ndim == 3)
assert(uv.shape[2] == 2)
u = uv[:,:,0]
v = uv[:,:,1]
else:
u = uv
assert(u.shape == v.shape)
height,width = u.shape
f = open(filename,'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width*nBands))
tmp[:,np.arange(width)*2] = u
tmp[:,np.arange(width)*2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
flow = flow[:,:,::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
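A round-trip sketch (illustrative only) for the Middlebury `.flo` helpers above:

```python
import numpy as np

# Toy round trip: write a random flow field to .flo and read it back unchanged.
flow = np.random.randn(32, 48, 2).astype(np.float32)
writeFlow('tmp.flo', flow)
restored = readFlow('tmp.flo')
assert restored.shape == (32, 48, 2) and np.allclose(flow, restored)
```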

View File

@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self,x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))  # default 'xy' indexing: x0, y0 have shape (ht, wd) like dx, dy
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij')
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
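A quick sketch (illustrative only) of `InputPadder`, which the RAFT wrapper uses to make both sides divisible by 8 and to undo the padding afterwards:

```python
import torch

# Toy padding round trip: pad to multiples of 8, then crop back to the original size.
padder = InputPadder((1, 3, 250, 370))              # dims of the tensors to pad
a, b = torch.randn(1, 3, 250, 370), torch.randn(1, 3, 250, 370)
a_pad, b_pad = padder.pad(a, b)
print(a_pad.shape)                                  # torch.Size([1, 3, 256, 376])
print(padder.unpad(a_pad).shape)                    # torch.Size([1, 3, 250, 370])
```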

View File

@ -417,7 +417,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
hints= None,
context_scale=1.0,
context_scale=[1.0],
cam_emb= None,
block_mask = None,
audio_proj= None,
@ -431,7 +431,7 @@ class WanAttentionBlock(nn.Module):
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
hint = None
hints_processed = None
attention_dtype = self.self_attn.q.weight.dtype
dtype = x.dtype
@ -442,10 +442,13 @@ class WanAttentionBlock(nn.Module):
"context" : context,
"e" : e,
}
if self.block_id == 0:
hint = self.vace(hints, x, **kwargs)
hints_processed= []
for scale, hint in zip(context_scale, hints):
if scale == 0:
hints_processed.append(None)
else:
hint = self.vace(hints, None, **kwargs)
hints_processed.append(self.vace(hint, x, **kwargs) if self.block_id == 0 else self.vace(hint, None, **kwargs))
latent_frames = e.shape[0]
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
@ -506,11 +509,13 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
if hint is not None:
if context_scale == 1:
if hints_processed is not None:
for hint, scale in zip(hints_processed, context_scale):
if scale != 0:
if scale == 1:
x.add_(hint)
else:
x.add_(hint, alpha= context_scale)
x.add_(hint, alpha= scale)
return x
@ -605,8 +610,8 @@ class MLPProj(torch.nn.Module):
class WanModel(ModelMixin, ConfigMixin):
def setup_chipmunk(self):
from chipmunk.util import LayerCounter
from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
# from chipmunk.util import LayerCounter
# from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
seq_shape = (21, 45, 80)
chipmunk_layers =[]
for i in range(self.num_layers):
@ -941,7 +946,7 @@ class WanModel(ModelMixin, ConfigMixin):
t,
context,
vace_context = None,
vace_context_scale=1.0,
vace_context_scale=[1.0],
clip_fea=None,
y=None,
freqs = None,
@ -972,7 +977,7 @@ class WanModel(ModelMixin, ConfigMixin):
chipmunk = offload.shared_state.get("_chipmunk", False)
if chipmunk:
from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
# from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
voxel_shape = (4, 6, 8)
x_list = x
@ -1065,10 +1070,10 @@ class WanModel(ModelMixin, ConfigMixin):
# Vace embeddings
c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = c[0]
# c = c[0]
c = [ [sub_c] for sub_c in c]
kwargs['context_scale'] = vace_context_scale
hints_list = [ [c] for _ in range(len(x_list)) ]
hints_list = [ c ]* len(x_list)
del c
should_calc = True

View File

@ -132,7 +132,7 @@ class WanT2V:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
self.toto = inactive[0].clone()
if overlapped_latents != None :
# inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
@ -220,7 +220,7 @@ class WanT2V:
ref_img = white_canvas
return ref_img.to(device)
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
image_sizes = []
trim_video = len(keep_frames)
def conv_tensor(t, device):
@ -254,7 +254,7 @@ class WanT2V:
image_sizes.append(image_size)
else:
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
@ -278,15 +278,15 @@ class WanT2V:
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
if ref_img is not None and not torch.is_tensor(ref_img):
src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
# else:
# assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
@ -312,7 +312,7 @@ class WanT2V:
input_ref_images = None,
input_video=None,
target_camera=None,
context_scale=1.0,
context_scale=None,
width = 1280,
height = 720,
fit_into_canvas = True,
@ -478,6 +478,7 @@ class WanT2V:
if vace:
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if overlapped_latents != None :
overlapped_latents_size = overlapped_latents.shape[1] + 1

wgp.py
View File

@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.9"
WanGP_version = "6.21"
WanGP_version = "6.3"
settings_version = 2
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -429,8 +429,8 @@ def process_prompt_and_add_tasks(state, model_choice):
return update_queue_data(queue)
def get_preview_images(inputs):
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"]
labels = ["Start Image", "End Image", "Video Guide", "Image Reference","Video Mask", "Video Source"]
inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "video_mask", "image_refs" ]
labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Video Mask", "Image Reference"]
start_image_data = None
start_image_labels = []
end_image_data = None
@ -1768,7 +1768,7 @@ def fix_settings(model_type, ui_defaults):
return
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom", "phantom_1.3B"]:
if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"]:
if not "I" in video_prompt_type: # workaround for settings corruption
video_prompt_type += "I"
if model_type in ["hunyuan"]:
@ -2081,8 +2081,8 @@ def download_models(model_filename, model_type):
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
"sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
[ "flownet.pkl" ] ]
}
process_files_def(**shared_def)
@ -2512,6 +2512,7 @@ def apply_changes( state,
enhancer_enabled_choice = 0,
fit_canvas_choice = 0,
preload_in_VRAM_choice = 0,
depth_anything_v2_variant_choice = "vitl",
notification_sound_enabled_choice = 1,
notification_sound_volume_choice = 50
):
@ -2540,6 +2541,7 @@ def apply_changes( state,
"fit_canvas": fit_canvas_choice,
"enhancer_enabled" : enhancer_enabled_choice,
"preload_in_VRAM" : preload_in_VRAM_choice,
"depth_anything_v2_variant": depth_anything_v2_variant_choice,
"notification_sound_enabled" : notification_sound_enabled_choice,
"notification_sound_volume" : notification_sound_volume_choice
}
@ -2577,7 +2579,7 @@ def apply_changes( state,
transformer_types = server_config["transformer_types"]
model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "notification_sound_enabled", "notification_sound_volume"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume"] for change in changes ):
model_choice = gr.Dropdown()
else:
reload_needed = True
@ -2818,7 +2820,7 @@ def get_preprocessor(process_type, inpaint_color):
"POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
"RESIZE_SIZE": 1024
}
anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)[0]
anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)
elif process_type=="depth":
# from preprocessing.midas.depth import DepthVideoAnnotator
# cfg_dict = {
@ -2827,30 +2829,72 @@ def get_preprocessor(process_type, inpaint_color):
# anno_ins = lambda img: DepthVideoAnnotator(cfg_dict).forward(img)[0]
from preprocessing.depth_anything_v2.depth import DepthV2VideoAnnotator
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth"
# "PRETRAINED_MODEL": "ckpts/depth/depth_anything_vitb14.pth"
if server_config.get("depth_anything_v2_variant", "vitl") == "vitl":
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth",
'MODEL_VARIANT': 'vitl'
}
anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)[0]
else:
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitb.pth",
'MODEL_VARIANT': 'vitb',
}
anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)
elif process_type=="gray":
from preprocessing.gray import GrayVideoAnnotator
cfg_dict = {}
anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0]
anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)
elif process_type=="scribble":
from preprocessing.scribble import ScribbleVideoAnnotator
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"
}
anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0]
anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)
elif process_type=="flow":
from preprocessing.flow import FlowVisAnnotator
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/flow/raft-things.pth"
}
anno_ins = lambda img: FlowVisAnnotator(cfg_dict).forward(img)
elif process_type=="inpaint":
anno_ins = lambda img : inpaint_color
# anno_ins = lambda img : np.full_like(img, inpaint_color)
anno_ins = lambda img : len(img) * [inpaint_color]
elif process_type == None or process_type in ["vace", "identity"]:
anno_ins = lambda img : img
else:
anno_ins = lambda img : img[0]
raise Exception(f"process type '{process_type}' not supported")
return anno_ins
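A hedged usage sketch of the factory above, assuming the flow checkpoint ckpts/flow/raft-things.pth has been downloaded and that frames is a list of RGB uint8 numpy arrays; after this commit the annotators are expected to return one control image per input frame:

import numpy as np

frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(8)]   # dummy clip
flow_annotator = get_preprocessor("flow", inpaint_color=127)
flow_maps = flow_annotator(frames)   # expected: one flow visualisation per frame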
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = max(os.cpu_count() // 2, 1)):
if not items:
return []
import concurrent.futures
start_time = time.time()
print(f"Preprocessus:{process_type} started")
if process_type in ["prephase", "upsample"]:
if wrap_in_list :
items = [ [img] for img in items]
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
results = [None] * len(items)
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
results[idx] = future.result()
if wrap_in_list:
results = [ img[0] for img in results]
else:
results= image_processor(items)
end_time = time.time()
print(f"duration:{end_time-start_time:.1f}")
return results
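A hedged usage sketch of the helper above: any per-item callable can be fanned out across threads for the "prephase" and "upsample" passes, while other process types fall back to a single batched call (the toy resizer below is illustrative only):

from PIL import Image

def double_size(img):
    # toy per-frame processor, stands in for a preprocessor or Lanczos upsampler
    return img.resize((img.width * 2, img.height * 2), resample=Image.Resampling.LANCZOS)

frames = [Image.new("RGB", (320, 180)) for _ in range(16)]
upscaled = process_images_multithread(double_size, frames, "upsample", wrap_in_list=False)
assert upscaled[0].size == (640, 360)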
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
def mask_to_xyxy_box(mask):
@ -2871,14 +2915,21 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
return None, None
any_mask = input_mask_path != None
pose_special = "pose" in process_type
if process_type == "pose_depth": process_type = "pose"
any_identity_mask = False
if process_type == "identity":
any_identity_mask = True
negate_mask = False
process_outside_mask = None
preproc = get_preprocessor(process_type, inpaint_color)
preproc2 = get_preprocessor(process_outside_mask, inpaint_color) if process_type != process_outside_mask else preproc
preproc2 = None
if process_type2 != None:
preproc2 = get_preprocessor(process_type2, inpaint_color) if process_type != process_type2 else preproc
if process_outside_mask == process_type :
preproc_outside = preproc
elif preproc2 != None and process_outside_mask == process_type2 :
preproc_outside = preproc2
else:
preproc_outside = get_preprocessor(process_outside_mask, inpaint_color)
video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
if any_mask:
mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
@ -2905,12 +2956,16 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
num_frames = min(len(video), len(mask_video))
else:
num_frames = len(video)
masked_frames = []
masks = []
if any_identity_mask:
any_mask = True
for frame_idx in range(num_frames):
proc_list =[]
proc_list_outside =[]
proc_mask = []
# for frame_idx in range(num_frames):
def prep_prephase(frame_idx):
frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy()
frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS)
frame = np.array(frame)
@ -2946,17 +3001,33 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
else:
target_frame = frame
processed_img = preproc([target_frame])
if any_mask:
if preproc2 != None:
frame = preproc2([frame])
masked_frame = np.where(mask[..., None], processed_img, frame)
return (target_frame, frame, mask)
else:
masked_frame = processed_img
return (target_frame, None, None)
proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
for frame_idx, frame_group in enumerate(proc_lists):
proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
prep_prephase = None
video = None
mask_video = None
if preproc2 != None:
proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
#### to be finished ...or not
proc_list = process_images_multithread(preproc, proc_list, process_type)
if any_mask:
proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
else:
proc_list_outside = proc_mask = len(proc_list) * [None]
masked_frames = []
masks = []
for frame_no, (processed_img, processed_img_outside, mask) in enumerate(zip(proc_list, proc_list_outside, proc_mask)):
if any_mask :
masked_frame = np.where(mask[..., None], processed_img, processed_img_outside)
if process_outside_mask != None:
mask = np.full_like(mask, 255)
mask = torch.from_numpy(mask)
@ -2967,6 +3038,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
mask = full_frame
masks.append(mask)
else:
masked_frame = processed_img
if isinstance(masked_frame, int):
masked_frame= np.full( (height, width, 3), inpaint_color, dtype= np.uint8)
@ -2981,15 +3054,18 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
masked_frame = full_frame
masked_frames.append(masked_frame)
proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
if args.save_masks:
from preprocessing.dwpose.pose import save_one_video
saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
save_one_video("masked_frames.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
if any_mask:
saved_masks = [mask.cpu().numpy() for mask in masks ]
save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
preproc = None
preproc2 = None
preproc_outside = None
gc.collect()
torch.cuda.empty_cache()
@ -3066,11 +3142,11 @@ def parse_keep_frames_video_guide(keep_frames, video_length):
for i in range(start_range, end_range + 1):
frames[i] = True
else:
if not is_integer(section):
if not is_integer(section) or int(section) == 0:
error =f"Invalid integer {section}"
break
index = absolute(int(section))
frames[index] = True
frames[index-1] = True
if len(error ) > 0:
return [], error
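A hedged illustration of the new 1-based convention (space-separated sections and the exact return format are assumed from this hunk and the WanGP UI):

frames, error = parse_keep_frames_video_guide("1 4", 8)
# expected: frames[0] and frames[3] are True and error == "", while "0" is now rejected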
@ -3113,6 +3189,8 @@ def generate_video(
keep_frames_video_guide,
video_guide_outpainting,
video_mask,
control_net_weight,
control_net_weight2,
mask_expand,
audio_guide,
sliding_window_size,
@ -3397,6 +3475,7 @@ def generate_video(
frames_already_processed = None
pre_video_guide = None
overlapped_latents = None
context_scale = None
window_no = 0
extra_windows = 0
guide_start_frame = 0
@ -3496,60 +3575,48 @@ def generate_video(
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
context_scale = [ control_net_weight]
if "V" in video_prompt_type:
extra_label = ""
if "X" in video_prompt_type:
process_outside_mask = "inpaint"
elif "Y" in video_prompt_type:
process_outside_mask = "depth"
extra_label = " and Depth"
elif "W" in video_prompt_type:
process_outside_mask = "scribble"
extra_label = " and Shapes"
process_map = { "Y" : "depth", "W": "scribble", "X": "inpaint", "Z": "flow"}
process_outside_mask = process_map.get(filter_letters(video_prompt_type, "XYZW"), None)
preprocess_type2 = preprocess_type = None
process_map = { "D" : "depth", "P": "pose", "S": "scribble", "F": "flow", "C": "gray", "M": "inpaint", "U": "identity"}
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSFCMU")):
if process_num == 0:
preprocess_type = process_map.get(process_letter, "vace")
else:
process_outside_mask = None
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
preprocess_type = "depth"
elif "S" in video_prompt_type :
progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
preprocess_type = "scribble"
elif "C" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
preprocess_type = "gray"
elif "M" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Creating Inpainting{extra_label} Mask")]
preprocess_type = "inpaint"
elif "U" in video_prompt_type :
progress_args = [0, get_latest_status(state,f"Creating Identity{extra_label} Mask")]
preprocess_type = "identity"
else:
progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
preprocess_type = "vace"
send_cmd("progress", progress_args)
video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
preprocess_type2 = process_map.get(process_letter, None)
process_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "identity": "Identity Mask", "vace" : "Vace Data"}
status_info = "Extracting " + process_names[preprocess_type]
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
if len(extra_process_list) == 1:
status_info += " and " + process_names[extra_process_list[0]]
elif len(extra_process_list) == 2:
status_info += ", " + process_names[extra_process_list[0]] + " and " + process_names[extra_process_list[1]]
send_cmd("progress", [0, get_latest_status(state, status_info)])
video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 )
video_guide_copy2 = video_mask_copy2 = None
if preprocess_type2 != None:
video_guide_copy2, video_mask_copy2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
if video_guide_copy != None:
if sample_fit_canvas != None:
image_size = video_guide_copy.shape[-3: -1]
sample_fit_canvas = None
refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
if video_guide_copy2 != None:
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_copy2[0].cpu().numpy())]
if video_mask_copy != None:
refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length]
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy],
[image_refs_copy],
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy] if video_guide_copy2 == None else [video_guide_copy, video_guide_copy2],
[video_mask_copy] if video_guide_copy2 == None else [video_mask_copy, video_mask_copy2],
[image_refs_copy] if video_guide_copy2 == None else [image_refs_copy, image_refs_copy],
current_video_length, image_size = image_size, device ="cpu",
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed,
start_frame = guide_start_frame,
pre_src_video = [pre_video_guide],
pre_src_video = [pre_video_guide] if video_guide_copy2 == None else [pre_video_guide, pre_video_guide],
fit_into_canvas = sample_fit_canvas,
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
@ -3632,6 +3699,7 @@ def generate_video(
audio_proj= audio_proj_split,
audio_scale= audio_scale,
audio_context_lens= audio_context_lens,
context_scale = context_scale,
ar_step = model_mode, #5
causal_block_size = 5,
causal_attention = True,
@ -3768,15 +3836,12 @@ def generate_video(
w *= scale
h = int(h)
w = int(w)
new_frames =[]
for i in range( sample.shape[1] ):
frame = sample[:, i]
frame = resize_lanczos(frame, h, w)
frame = frame.unsqueeze(1)
new_frames.append(frame)
sample = torch.cat(new_frames, dim=1)
new_frames = None
sample = sample * 2 - 1
frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
def upsample_frames(frame):
return resize_lanczos(frame, h, w).unsqueeze(1)
sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
frames_to_upsample = None
sample.mul_(2).sub_(1)
if sliding_window :
if frames_already_processed == None:
@ -4764,6 +4829,8 @@ def save_inputs(
video_guide,
keep_frames_video_guide,
video_mask,
control_net_weight,
control_net_weight2,
mask_expand,
audio_guide,
sliding_window_size,
@ -4955,13 +5022,13 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA")
video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
visible= "A" in video_prompt_type
return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV")
video_prompt_type = del_in_sequence(video_prompt_type, "PDSFCMUV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
@ -5242,12 +5309,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Transfer Human Motion", "PV"),
("Transfer Depth", "DV"),
("Transfer Shapes", "SV"),
("Transfer Flow", "FV"),
("Recolorize", "CV"),
("Inpainting", "MV"),
("Vace raw format", "V"),
("Perform Inpainting", "MV"),
("Use Vace raw format", "V"),
("Keep Unchanged", "UV"),
("Transfer Human Motion & Depth", "PDV"),
("Transfer Human Motion & Shape", "PSV"),
("Transfer Human Motion & Flow", "PFV"),
("Transfer Depth & Shape", "DSV"),
("Transfer Depth & Flow", "DFV"),
("Transfer Shapes & Flow", "SFV"),
],
value=filter_letters(video_prompt_type_value, "DSPCMUV"),
value=filter_letters(video_prompt_type_value, "PDSFCMUV"),
label="Control Video Process", scale = 2, visible= True
)
elif hunyuan_video_custom_edit:
@ -5256,7 +5330,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Inpaint Control Video", "MV"),
("Transfer Human Motion", "PMV"),
],
value=filter_letters(video_prompt_type_value, "DSPCMUV"),
value=filter_letters(video_prompt_type_value, "PDSFCMUV"),
label="Video to Video", scale = 3, visible= True
)
else:
@ -5286,8 +5360,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Non Masked Area, rest Depth", "YNA"),
("Masked Area, rest Shapes", "WA"),
("Non Masked Area, rest Shapes", "WNA"),
("Masked Area, rest Flow", "ZA"),
("Non Masked Area, rest Flow", "ZNA"),
],
value= filter_letters(video_prompt_type_value, "XYWNA"),
value= filter_letters(video_prompt_type_value, "XYZWNA"),
visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
label="Area Processed", scale = 2
)
@ -5466,6 +5542,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v))
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row(visible = vace):
control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace)
control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace)
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
with gr.Tab("Loras"):
@ -6144,6 +6223,15 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan
interactive= not lock_ui_compile
)
depth_anything_v2_variant_choice = gr.Dropdown(
choices=[
("Large (more precise but 2x slower)", "vitl"),
("Big (less precise, less VRAM needed but faster)", "vitb"),
],
value= server_config.get("depth_anything_v2_variant", "vitl"),
label="Depth Anything v2 Vace Preprocessor Model type",
)
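For reference, the chosen value lands in server_config and drives the checkpoint selection shown earlier in get_preprocessor; a hedged sketch of the mapping:

# "vitl" keeps the previous Large checkpoint, "vitb" switches both the checkpoint
# path and the MODEL_VARIANT handed to DepthV2VideoAnnotator
variant = server_config.get("depth_anything_v2_variant", "vitl")
checkpoint = f"ckpts/depth/depth_anything_v2_{variant}.pth"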
vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
@ -6224,6 +6312,7 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan
enhancer_enabled_choice,
fit_canvas_choice,
preload_in_VRAM_choice,
depth_anything_v2_variant_choice,
notification_sound_enabled_choice,
notification_sound_volume_choice
],