From 28fe3528e83e2d0e914ca915b577918c624adbe1 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 23 Jun 2025 19:48:19 +0200 Subject: [PATCH] Multi Vace controlnets and multithreaded preprocessing --- README.md | 10 + preprocessing/depth_anything_v2/depth.py | 28 ++- preprocessing/dwpose/pose.py | 187 +++++++++++++-- preprocessing/flow.py | 58 +++++ preprocessing/raft/__init__.py | 0 preprocessing/raft/corr.py | 91 ++++++++ preprocessing/raft/datasets.py | 235 +++++++++++++++++++ preprocessing/raft/extractor.py | 267 ++++++++++++++++++++++ preprocessing/raft/raft.py | 144 ++++++++++++ preprocessing/raft/update.py | 139 +++++++++++ preprocessing/raft/utils/__init__.py | 0 preprocessing/raft/utils/augmentor.py | 246 ++++++++++++++++++++ preprocessing/raft/utils/flow_viz.py | 132 +++++++++++ preprocessing/raft/utils/frame_utils.py | 137 +++++++++++ preprocessing/raft/utils/utils.py | 82 +++++++ wan/modules/model.py | 41 ++-- wan/text2video.py | 17 +- wgp.py | 279 +++++++++++++++-------- 18 files changed, 1951 insertions(+), 142 deletions(-) create mode 100644 preprocessing/flow.py create mode 100644 preprocessing/raft/__init__.py create mode 100644 preprocessing/raft/corr.py create mode 100644 preprocessing/raft/datasets.py create mode 100644 preprocessing/raft/extractor.py create mode 100644 preprocessing/raft/raft.py create mode 100644 preprocessing/raft/update.py create mode 100644 preprocessing/raft/utils/__init__.py create mode 100644 preprocessing/raft/utils/augmentor.py create mode 100644 preprocessing/raft/utils/flow_viz.py create mode 100644 preprocessing/raft/utils/frame_utils.py create mode 100644 preprocessing/raft/utils/utils.py diff --git a/README.md b/README.md index 58e59a8..46a7dd1 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,16 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates +### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ? +- Multithreaded preprocessing when possible for faster generations +- Multithreaded frames Lanczos Upsampling as a bonus +- A new Vace preprocessor : *Flow* to extract fluid motion +- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character. +- Injected Frames Outpainting, in case you missed it in WanGP 6.21 + +Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated. + + ### June 19 2025: WanGP v6.2, Vace even more Powercharged 👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: - If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time diff --git a/preprocessing/depth_anything_v2/depth.py b/preprocessing/depth_anything_v2/depth.py index fd4fcdd..3f66161 100644 --- a/preprocessing/depth_anything_v2/depth.py +++ b/preprocessing/depth_anything_v2/depth.py @@ -20,9 +20,33 @@ def convert_to_numpy(image): class DepthV2Annotator: def __init__(self, cfg, device=None): from .dpt import DepthAnythingV2 + + # Model configurations for different variants + self.model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} + } + + # Get model variant from config, default to 'vitl' if not specified + model_variant = cfg.get('MODEL_VARIANT', 'vitl') + if model_variant not in self.model_configs: + raise ValueError(f"Invalid model variant '{model_variant}'. Must be one of: {list(self.model_configs.keys())}") + pretrained_model = cfg['PRETRAINED_MODEL'] self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device - self.model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(self.device) + + # Get configuration for the selected model variant + config = self.model_configs[model_variant] + + # Initialize model with the appropriate configuration + self.model = DepthAnythingV2( + encoder=config['encoder'], + features=config['features'], + out_channels=config['out_channels'] + ).to(self.device) + self.model.load_state_dict( torch.load( pretrained_model, @@ -54,4 +78,4 @@ class DepthV2VideoAnnotator(DepthV2Annotator): for frame in frames: anno_frame = super().forward(np.array(frame)) ret_frames.append(anno_frame) - return ret_frames + return ret_frames \ No newline at end of file diff --git a/preprocessing/dwpose/pose.py b/preprocessing/dwpose/pose.py index ff7b0ca..f5e67cf 100644 --- a/preprocessing/dwpose/pose.py +++ b/preprocessing/dwpose/pose.py @@ -2,13 +2,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os - import cv2 import torch import numpy as np from . import util from .wholebody import Wholebody, HWC3, resize_image from PIL import Image +import onnxruntime as ort +from concurrent.futures import ThreadPoolExecutor +import threading os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" @@ -23,8 +25,6 @@ def convert_to_numpy(image): raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.' return image - - def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False): bodies = pose['bodies'] faces = pose['faces'] @@ -43,6 +43,56 @@ def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False): return canvas +class OptimizedWholebody: + """Optimized version of Wholebody for faster serial processing""" + def __init__(self, onnx_det, onnx_pose, device='cuda:0'): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + self.device = device + + # Pre-allocate session options for better performance + self.session_det.set_providers(providers) + self.session_pose.set_providers(providers) + + # Get input names once to avoid repeated lookups + self.det_input_name = self.session_det.get_inputs()[0].name + self.pose_input_name = self.session_pose.get_inputs()[0].name + self.pose_output_names = [out.name for out in self.session_pose.get_outputs()] + + def __call__(self, ori_img): + from .onnxdet import inference_detector + from .onnxpose import inference_pose + + det_result = inference_detector(self.session_det, ori_img) + keypoints, scores = inference_pose(self.session_pose, det_result, ori_img) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores, det_result + + class PoseAnnotator: def __init__(self, cfg, device=None): onnx_det = cfg['DETECTION_MODEL'] @@ -84,9 +134,7 @@ class PoseAnnotator: candidate[un_visible] = -1 foot = candidate[:, 18:24] - faces = candidate[:, 24:92] - hands = candidate[:, 92:113] hands = np.vstack([hands, candidate[:, 113:]]) @@ -127,10 +175,24 @@ class PoseAnnotator: return ret_data, det_result +class OptimizedPoseAnnotator(PoseAnnotator): + """Optimized version using improved Wholebody class""" + def __init__(self, cfg, device=None): + onnx_det = cfg['DETECTION_MODEL'] + onnx_pose = cfg['POSE_MODEL'] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.pose_estimation = OptimizedWholebody(onnx_det, onnx_pose, device=self.device) + self.resize_size = cfg.get("RESIZE_SIZE", 1024) + self.use_body = cfg.get('USE_BODY', True) + self.use_face = cfg.get('USE_FACE', True) + self.use_hand = cfg.get('USE_HAND', True) + + class PoseBodyFaceAnnotator(PoseAnnotator): def __init__(self, cfg): super().__init__(cfg) self.use_body, self.use_face, self.use_hand = True, True, False + @torch.no_grad() @torch.inference_mode def forward(self, image): @@ -138,14 +200,108 @@ class PoseBodyFaceAnnotator(PoseAnnotator): return ret_data['detected_map_bodyface'] -class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator): +class OptimizedPoseBodyFaceVideoAnnotator: + """Optimized video annotator with multiple optimization strategies""" + def __init__(self, cfg, num_workers=5, chunk_size=8): + self.cfg = cfg + self.num_workers = num_workers + self.chunk_size = chunk_size + self.use_body, self.use_face, self.use_hand = True, True, False + + # Initialize one annotator per worker to avoid ONNX session conflicts + self.annotators = [] + for _ in range(num_workers): + annotator = OptimizedPoseAnnotator(cfg) + annotator.use_body, annotator.use_face, annotator.use_hand = True, True, False + self.annotators.append(annotator) + + self._current_worker = 0 + self._worker_lock = threading.Lock() + + def _get_annotator(self): + """Get next available annotator in round-robin fashion""" + with self._worker_lock: + annotator = self.annotators[self._current_worker] + self._current_worker = (self._current_worker + 1) % len(self.annotators) + return annotator + + def _process_single_frame(self, frame_data): + """Process a single frame with error handling""" + frame, frame_idx = frame_data + try: + annotator = self._get_annotator() + + # Convert frame + frame = convert_to_numpy(frame) + input_image = HWC3(frame[..., ::-1]) + resized_image = resize_image(input_image, annotator.resize_size) + + # Process + ret_data, _ = annotator.process(resized_image, frame.shape[:2]) + + if 'detected_map_bodyface' in ret_data: + return frame_idx, ret_data['detected_map_bodyface'] + else: + # Create empty frame if no detection + h, w = frame.shape[:2] + return frame_idx, np.zeros((h, w, 3), dtype=np.uint8) + + except Exception as e: + print(f"Error processing frame {frame_idx}: {e}") + # Return empty frame on error + h, w = frame.shape[:2] if hasattr(frame, 'shape') else (480, 640) + return frame_idx, np.zeros((h, w, 3), dtype=np.uint8) + def forward(self, frames): - ret_frames = [] - for frame in frames: - anno_frame = super().forward(np.array(frame)) - ret_frames.append(anno_frame) - return ret_frames + """Process video frames with optimizations""" + if len(frames) == 0: + return [] + + # For small number of frames, use serial processing to avoid threading overhead + if len(frames) <= 4: + annotator = self.annotators[0] + ret_frames = [] + for frame in frames: + frame = convert_to_numpy(frame) + input_image = HWC3(frame[..., ::-1]) + resized_image = resize_image(input_image, annotator.resize_size) + ret_data, _ = annotator.process(resized_image, frame.shape[:2]) + + if 'detected_map_bodyface' in ret_data: + ret_frames.append(ret_data['detected_map_bodyface']) + else: + h, w = frame.shape[:2] + ret_frames.append(np.zeros((h, w, 3), dtype=np.uint8)) + return ret_frames + + # For larger videos, use parallel processing + frame_data = [(frame, idx) for idx, frame in enumerate(frames)] + results = [None] * len(frames) + + # Process in chunks to manage memory + for chunk_start in range(0, len(frame_data), self.chunk_size * self.num_workers): + chunk_end = min(chunk_start + self.chunk_size * self.num_workers, len(frame_data)) + chunk_data = frame_data[chunk_start:chunk_end] + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + chunk_results = list(executor.map(self._process_single_frame, chunk_data)) + + # Store results in correct order + for frame_idx, result in chunk_results: + results[frame_idx] = result + + return results + +# Alias for backward compatibility +class PoseBodyFaceVideoAnnotator(OptimizedPoseBodyFaceVideoAnnotator): + """Backward compatible class name""" + def __init__(self, cfg, num_workers=2, chunk_size=8): + # Use optimized version with conservative settings + super().__init__(cfg, num_workers=num_workers, chunk_size=chunk_size) + + +# Keep the existing utility functions import imageio def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None): @@ -161,11 +317,7 @@ def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None): def get_frames(video_path): frames = [] - - - # Opens the Video file with CV2 cap = cv2.VideoCapture(video_path) - fps = cap.get(cv2.CAP_PROP_FPS) print("video fps: " + str(fps)) i = 0 @@ -175,9 +327,6 @@ def get_frames(video_path): break frames.append(frame) i += 1 - cap.release() cv2.destroyAllWindows() - - return frames, fps - + return frames, fps \ No newline at end of file diff --git a/preprocessing/flow.py b/preprocessing/flow.py new file mode 100644 index 0000000..0b5c39d --- /dev/null +++ b/preprocessing/flow.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import numpy as np +import argparse +from PIL import Image + +def convert_to_numpy(image): + if isinstance(image, Image.Image): + image = np.array(image) + elif isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + elif isinstance(image, np.ndarray): + image = image.copy() + else: + raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.' + return image + +class FlowAnnotator: + def __init__(self, cfg, device=None): + from .raft.raft import RAFT + from .raft.utils.utils import InputPadder + from .raft.utils import flow_viz + + params = { + "small": False, + "mixed_precision": False, + "alternate_corr": False + } + params = argparse.Namespace(**params) + pretrained_model = cfg['PRETRAINED_MODEL'] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.model = RAFT(params) + self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()}) + self.model = self.model.to(self.device).eval() + self.InputPadder = InputPadder + self.flow_viz = flow_viz + + def forward(self, frames): + # frames / RGB + frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames] + flow_up_list, flow_up_vis_list = [], [] + with torch.no_grad(): + for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])): + padder = self.InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True) + flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy() + flow_up_vis = self.flow_viz.flow_to_image(flow_up) + flow_up_list.append(flow_up) + flow_up_vis_list.append(flow_up_vis) + return flow_up_list, flow_up_vis_list # RGB + + +class FlowVisAnnotator(FlowAnnotator): + def forward(self, frames): + flow_up_list, flow_up_vis_list = super().forward(frames) + return flow_up_vis_list[:1] + flow_up_vis_list \ No newline at end of file diff --git a/preprocessing/raft/__init__.py b/preprocessing/raft/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/raft/corr.py b/preprocessing/raft/corr.py new file mode 100644 index 0000000..23a6ffd --- /dev/null +++ b/preprocessing/raft/corr.py @@ -0,0 +1,91 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/preprocessing/raft/datasets.py b/preprocessing/raft/datasets.py new file mode 100644 index 0000000..e456954 --- /dev/null +++ b/preprocessing/raft/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from .utils import frame_utils +from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/preprocessing/raft/extractor.py b/preprocessing/raft/extractor.py new file mode 100644 index 0000000..9a9c759 --- /dev/null +++ b/preprocessing/raft/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/preprocessing/raft/raft.py b/preprocessing/raft/raft.py new file mode 100644 index 0000000..5ffc746 --- /dev/null +++ b/preprocessing/raft/raft.py @@ -0,0 +1,144 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + if 'alternate_corr' not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast('cuda', enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast('cuda', enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast('cuda', enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/preprocessing/raft/update.py b/preprocessing/raft/update.py new file mode 100644 index 0000000..f940497 --- /dev/null +++ b/preprocessing/raft/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/preprocessing/raft/utils/__init__.py b/preprocessing/raft/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/raft/utils/augmentor.py b/preprocessing/raft/utils/augmentor.py new file mode 100644 index 0000000..3489a7e --- /dev/null +++ b/preprocessing/raft/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht), indexing='ij') + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/preprocessing/raft/utils/flow_viz.py b/preprocessing/raft/utils/flow_viz.py new file mode 100644 index 0000000..dcee65e --- /dev/null +++ b/preprocessing/raft/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/preprocessing/raft/utils/frame_utils.py b/preprocessing/raft/utils/frame_utils.py new file mode 100644 index 0000000..6c49113 --- /dev/null +++ b/preprocessing/raft/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/preprocessing/raft/utils/utils.py b/preprocessing/raft/utils/utils.py new file mode 100644 index 0000000..e3144ae --- /dev/null +++ b/preprocessing/raft/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht), indexing='ij') + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/wan/modules/model.py b/wan/modules/model.py index aba6f98..fba8a67 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -417,7 +417,7 @@ class WanAttentionBlock(nn.Module): freqs, context, hints= None, - context_scale=1.0, + context_scale=[1.0], cam_emb= None, block_mask = None, audio_proj= None, @@ -431,7 +431,7 @@ class WanAttentionBlock(nn.Module): grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ - hint = None + hints_processed = None attention_dtype = self.self_attn.q.weight.dtype dtype = x.dtype @@ -442,10 +442,13 @@ class WanAttentionBlock(nn.Module): "context" : context, "e" : e, } - if self.block_id == 0: - hint = self.vace(hints, x, **kwargs) - else: - hint = self.vace(hints, None, **kwargs) + hints_processed= [] + for scale, hint in zip(context_scale, hints): + if scale == 0: + hints_processed.append(None) + else: + hints_processed.append(self.vace(hint, x, **kwargs) if self.block_id == 0 else self.vace(hint, None, **kwargs)) + latent_frames = e.shape[0] e = (self.modulation + e).chunk(6, dim=1) # self-attention @@ -506,11 +509,13 @@ class WanAttentionBlock(nn.Module): x.addcmul_(y, e[5]) x, y = reshape_latent(x , 1), reshape_latent(y , 1) - if hint is not None: - if context_scale == 1: - x.add_(hint) - else: - x.add_(hint, alpha= context_scale) + if hints_processed is not None: + for hint, scale in zip(hints_processed, context_scale): + if scale != 0: + if scale == 1: + x.add_(hint) + else: + x.add_(hint, alpha= scale) return x @@ -605,8 +610,8 @@ class MLPProj(torch.nn.Module): class WanModel(ModelMixin, ConfigMixin): def setup_chipmunk(self): - from chipmunk.util import LayerCounter - from chipmunk.modules import SparseDiffMlp, SparseDiffAttn + # from chipmunk.util import LayerCounter + # from chipmunk.modules import SparseDiffMlp, SparseDiffAttn seq_shape = (21, 45, 80) chipmunk_layers =[] for i in range(self.num_layers): @@ -941,7 +946,7 @@ class WanModel(ModelMixin, ConfigMixin): t, context, vace_context = None, - vace_context_scale=1.0, + vace_context_scale=[1.0], clip_fea=None, y=None, freqs = None, @@ -972,7 +977,7 @@ class WanModel(ModelMixin, ConfigMixin): chipmunk = offload.shared_state.get("_chipmunk", False) if chipmunk: - from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding + # from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding voxel_shape = (4, 6, 8) x_list = x @@ -1065,10 +1070,10 @@ class WanModel(ModelMixin, ConfigMixin): # Vace embeddings c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] - c = c[0] - + # c = c[0] + c = [ [sub_c] for sub_c in c] kwargs['context_scale'] = vace_context_scale - hints_list = [ [c] for _ in range(len(x_list)) ] + hints_list = [ c ]* len(x_list) del c should_calc = True diff --git a/wan/text2video.py b/wan/text2video.py index 1b83f13..4020393 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -132,7 +132,7 @@ class WanT2V: inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] inactive = self.vae.encode(inactive, tile_size = tile_size) - self.toto = inactive[0].clone() + if overlapped_latents != None : # inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents @@ -220,7 +220,7 @@ class WanT2V: ref_img = white_canvas return ref_img.to(device) - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None): + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None): image_sizes = [] trim_video = len(keep_frames) def conv_tensor(t, device): @@ -254,7 +254,7 @@ class WanT2V: image_sizes.append(image_size) else: src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) if prepend_count > 0: src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) @@ -278,15 +278,15 @@ class WanT2V: if ref_images is not None: image_size = image_sizes[i] for j, ref_img in enumerate(ref_images): - if ref_img is not None: + if ref_img is not None and not torch.is_tensor(ref_img): src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device) return src_video, src_mask, src_ref_images def decode_latent(self, zs, ref_images=None, tile_size= 0 ): if ref_images is None: ref_images = [None] * len(zs) - else: - assert len(zs) == len(ref_images) + # else: + # assert len(zs) == len(ref_images) trimed_zs = [] for z, refs in zip(zs, ref_images): @@ -312,7 +312,7 @@ class WanT2V: input_ref_images = None, input_video=None, target_camera=None, - context_scale=1.0, + context_scale=None, width = 1280, height = 720, fit_into_canvas = True, @@ -477,7 +477,8 @@ class WanT2V: kwargs.update({'cam_emb': cam_emb}) if vace: - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale}) if overlapped_latents != None : overlapped_latents_size = overlapped_latents.shape[1] + 1 diff --git a/wgp.py b/wgp.py index 3f40c8a..2a96c12 100644 --- a/wgp.py +++ b/wgp.py @@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.4.9" -WanGP_version = "6.21" +WanGP_version = "6.3" settings_version = 2 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -429,8 +429,8 @@ def process_prompt_and_add_tasks(state, model_choice): return update_queue_data(queue) def get_preview_images(inputs): - inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"] - labels = ["Start Image", "End Image", "Video Guide", "Image Reference","Video Mask", "Video Source"] + inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "video_mask", "image_refs" ] + labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Video Mask", "Image Reference"] start_image_data = None start_image_labels = [] end_image_data = None @@ -1768,7 +1768,7 @@ def fix_settings(model_type, ui_defaults): return video_prompt_type = ui_defaults.get("video_prompt_type", "") - if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom", "phantom_1.3B"]: + if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"]: if not "I" in video_prompt_type: # workaround for settings corruption video_prompt_type += "I" if model_type in ["hunyuan"]: @@ -2081,8 +2081,8 @@ def download_models(model_filename, model_type): shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", "" ], - "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "" ], + "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2512,6 +2512,7 @@ def apply_changes( state, enhancer_enabled_choice = 0, fit_canvas_choice = 0, preload_in_VRAM_choice = 0, + depth_anything_v2_variant_choice = "vitl", notification_sound_enabled_choice = 1, notification_sound_volume_choice = 50 ): @@ -2540,6 +2541,7 @@ def apply_changes( state, "fit_canvas": fit_canvas_choice, "enhancer_enabled" : enhancer_enabled_choice, "preload_in_VRAM" : preload_in_VRAM_choice, + "depth_anything_v2_variant": depth_anything_v2_variant_choice, "notification_sound_enabled" : notification_sound_enabled_choice, "notification_sound_volume" : notification_sound_volume_choice } @@ -2577,7 +2579,7 @@ def apply_changes( state, transformer_types = server_config["transformer_types"] model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename - if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "notification_sound_enabled", "notification_sound_volume"] for change in changes ): + if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume"] for change in changes ): model_choice = gr.Dropdown() else: reload_needed = True @@ -2818,7 +2820,7 @@ def get_preprocessor(process_type, inpaint_color): "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx", "RESIZE_SIZE": 1024 } - anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)[0] + anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img) elif process_type=="depth": # from preprocessing.midas.depth import DepthVideoAnnotator # cfg_dict = { @@ -2827,30 +2829,72 @@ def get_preprocessor(process_type, inpaint_color): # anno_ins = lambda img: DepthVideoAnnotator(cfg_dict).forward(img)[0] from preprocessing.depth_anything_v2.depth import DepthV2VideoAnnotator - cfg_dict = { - "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth" - # "PRETRAINED_MODEL": "ckpts/depth/depth_anything_vitb14.pth" - - } - anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)[0] + + if server_config.get("depth_anything_v2_variant", "vitl") == "vitl": + cfg_dict = { + "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth", + 'MODEL_VARIANT': 'vitl' + } + else: + cfg_dict = { + "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitb.pth", + 'MODEL_VARIANT': 'vitb', + } + + anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img) elif process_type=="gray": from preprocessing.gray import GrayVideoAnnotator cfg_dict = {} - anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0] + anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img) elif process_type=="scribble": from preprocessing.scribble import ScribbleVideoAnnotator cfg_dict = { "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth" } - anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0] + anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img) + elif process_type=="flow": + from preprocessing.flow import FlowVisAnnotator + cfg_dict = { + "PRETRAINED_MODEL": "ckpts/flow/raft-things.pth" + } + anno_ins = lambda img: FlowVisAnnotator(cfg_dict).forward(img) elif process_type=="inpaint": - anno_ins = lambda img : inpaint_color - # anno_ins = lambda img : np.full_like(img, inpaint_color) + anno_ins = lambda img : len(img) * [inpaint_color] + elif process_type == None or process_type in ["vace", "identity"]: + anno_ins = lambda img : img else: - anno_ins = lambda img : img[0] + raise Exception(f"process type '{process_type}' non supported") return anno_ins -def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None): + +def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : + if not items: + return [] + + import concurrent.futures + start_time = time.time() + print(f"Preprocessus:{process_type} started") + if process_type in ["prephase", "upsample"]: + if wrap_in_list : + items = [ [img] for img in items] + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() + + if wrap_in_list: + results = [ img[0] for img in results] + else: + results= image_processor(items) + + end_time = time.time() + print(f"duration:{end_time-start_time:.1f}") + + return results + +def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions def mask_to_xyxy_box(mask): @@ -2871,14 +2915,21 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, return None, None any_mask = input_mask_path != None pose_special = "pose" in process_type - if process_type == "pose_depth": process_type = "pose" any_identity_mask = False if process_type == "identity": any_identity_mask = True negate_mask = False process_outside_mask = None preproc = get_preprocessor(process_type, inpaint_color) - preproc2 = get_preprocessor(process_outside_mask, inpaint_color) if process_type != process_outside_mask else preproc + preproc2 = None + if process_type2 != None: + preproc2 = get_preprocessor(process_type2, inpaint_color) if process_type != process_type2 else preproc + if process_outside_mask == process_type : + preproc_outside = preproc + elif preproc2 != None and process_outside_mask == process_type2 : + preproc_outside = preproc2 + else: + preproc_outside = get_preprocessor(process_outside_mask, inpaint_color) video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) if any_mask: mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) @@ -2905,12 +2956,16 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, num_frames = min(len(video), len(mask_video)) else: num_frames = len(video) - masked_frames = [] - masks = [] + if any_identity_mask: any_mask = True - for frame_idx in range(num_frames): + proc_list =[] + proc_list_outside =[] + proc_mask = [] + + # for frame_idx in range(num_frames): + def prep_prephase(frame_idx): frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) frame = np.array(frame) @@ -2946,17 +3001,33 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, else: target_frame = frame - processed_img = preproc([target_frame]) - if any_mask: - if preproc2 != None: - frame = preproc2([frame]) - masked_frame = np.where(mask[..., None], processed_img, frame) + return (target_frame, frame, mask) else: - masked_frame = processed_img + return (target_frame, None, None) - + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) + for frame_idx, frame_group in enumerate(proc_lists): + proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group + prep_prephase = None + video = None + mask_video = None + + if preproc2 != None: + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + #### to be finished ...or not + proc_list = process_images_multithread(preproc, proc_list, process_type) + if any_mask: + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + else: + proc_list_outside = proc_mask = len(proc_list) * [None] + + masked_frames = [] + masks = [] + for frame_no, (processed_img, processed_img_outside, mask) in enumerate(zip(proc_list, proc_list_outside, proc_mask)): if any_mask : + masked_frame = np.where(mask[..., None], processed_img, processed_img_outside) if process_outside_mask != None: mask = np.full_like(mask, 255) mask = torch.from_numpy(mask) @@ -2967,6 +3038,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame masks.append(mask) + else: + masked_frame = processed_img if isinstance(masked_frame, int): masked_frame= np.full( (height, width, 3), inpaint_color, dtype= np.uint8) @@ -2981,15 +3054,18 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, masked_frame = full_frame masked_frames.append(masked_frame) + proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None + + if args.save_masks: from preprocessing.dwpose.pose import save_one_video saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video("masked_frames.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) if any_mask: saved_masks = [mask.cpu().numpy() for mask in masks ] save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None - preproc2 = None + preproc_outside = None gc.collect() torch.cuda.empty_cache() @@ -3066,11 +3142,11 @@ def parse_keep_frames_video_guide(keep_frames, video_length): for i in range(start_range, end_range + 1): frames[i] = True else: - if not is_integer(section): + if not is_integer(section) or int(section) == 0: error =f"Invalid integer {section}" break index = absolute(int(section)) - frames[index] = True + frames[index-1] = True if len(error ) > 0: return [], error @@ -3113,6 +3189,8 @@ def generate_video( keep_frames_video_guide, video_guide_outpainting, video_mask, + control_net_weight, + control_net_weight2, mask_expand, audio_guide, sliding_window_size, @@ -3397,6 +3475,7 @@ def generate_video( frames_already_processed = None pre_video_guide = None overlapped_latents = None + context_scale = None window_no = 0 extra_windows = 0 guide_start_frame = 0 @@ -3496,60 +3575,48 @@ def generate_video( if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length] - + context_scale = [ control_net_weight] if "V" in video_prompt_type: - extra_label = "" - if "X" in video_prompt_type: - process_outside_mask = "inpaint" - elif "Y" in video_prompt_type: - process_outside_mask = "depth" - extra_label = " and Depth" - elif "W" in video_prompt_type: - process_outside_mask = "scribble" - extra_label = " and Shapes" - else: - process_outside_mask = None - preprocess_type = None - if "P" in video_prompt_type : - progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")] - preprocess_type = "pose" - elif "D" in video_prompt_type : - progress_args = [0, get_latest_status(state,"Extracting Depth Information")] - preprocess_type = "depth" - elif "S" in video_prompt_type : - progress_args = [0, get_latest_status(state,"Extracting Shapes Information")] - preprocess_type = "scribble" - elif "C" in video_prompt_type : - progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")] - preprocess_type = "gray" - elif "M" in video_prompt_type : - progress_args = [0, get_latest_status(state,f"Creating Inpainting{extra_label} Mask")] - preprocess_type = "inpaint" - elif "U" in video_prompt_type : - progress_args = [0, get_latest_status(state,f"Creating Identity{extra_label} Mask")] - preprocess_type = "identity" - else: - progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")] - preprocess_type = "vace" - send_cmd("progress", progress_args) - video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims ) + process_map = { "Y" : "depth", "W": "scribble", "X": "inpaint", "Z": "flow"} + process_outside_mask = process_map.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type2 = preprocess_type = None + process_map = { "D" : "depth", "P": "pose", "S": "scribble", "F": "flow", "C": "gray", "M": "inpaint", "U": "identity"} + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSFCMU")): + if process_num == 0: + preprocess_type = process_map.get(process_letter, "vace") + else: + preprocess_type2 = process_map.get(process_letter, None) + process_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "U": "Identity Mask", "vace" : "Vace Data"} + status_info = "Extracting " + process_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + process_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + process_names[extra_process_list[0]] + " and " + process_names[extra_process_list[1]] + send_cmd("progress", [0, get_latest_status(state, status_info)]) + video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) + video_guide_copy2 = video_mask_copy2 = None + if preprocess_type2 != None: + video_guide_copy2, video_mask_copy2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + if video_guide_copy != None: if sample_fit_canvas != None: image_size = video_guide_copy.shape[-3: -1] sample_fit_canvas = None refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy()) + if video_guide_copy2 != None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_copy2[0].cpu().numpy())] if video_mask_copy != None: refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy()) frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length] - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy], - [video_mask_copy], - [image_refs_copy], + src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy] if video_guide_copy2 == None else [video_guide_copy, video_guide_copy2], + [video_mask_copy] if video_guide_copy2 == None else [video_mask_copy, video_mask_copy2], + [image_refs_copy] if video_guide_copy2 == None else [image_refs_copy, image_refs_copy], current_video_length, image_size = image_size, device ="cpu", - original_video= "O" in video_prompt_type, keep_frames=keep_frames_parsed, start_frame = guide_start_frame, - pre_src_video = [pre_video_guide], + pre_src_video = [pre_video_guide] if video_guide_copy2 == None else [pre_video_guide, pre_video_guide], fit_into_canvas = sample_fit_canvas, inject_frames= frames_to_inject_parsed, outpainting_dims = outpainting_dims, @@ -3567,7 +3634,7 @@ def generate_video( send_cmd("progress", progress_args) src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) + refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) if src_mask != None: refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) if len(refresh_preview) > 0: @@ -3632,6 +3699,7 @@ def generate_video( audio_proj= audio_proj_split, audio_scale= audio_scale, audio_context_lens= audio_context_lens, + context_scale = context_scale, ar_step = model_mode, #5 causal_block_size = 5, causal_attention = True, @@ -3768,15 +3836,12 @@ def generate_video( w *= scale h = int(h) w = int(w) - new_frames =[] - for i in range( sample.shape[1] ): - frame = sample[:, i] - frame = resize_lanczos(frame, h, w) - frame = frame.unsqueeze(1) - new_frames.append(frame) - sample = torch.cat(new_frames, dim=1) - new_frames = None - sample = sample * 2 - 1 + frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] + def upsample_frames(frame): + return resize_lanczos(frame, h, w).unsqueeze(1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + frames_to_upsample = None + sample.mul_(2).sub_(1) if sliding_window : if frames_already_processed == None: @@ -4764,6 +4829,8 @@ def save_inputs( video_guide, keep_frames_video_guide, video_mask, + control_net_weight, + control_net_weight2, mask_expand, audio_guide, sliding_window_size, @@ -4955,13 +5022,13 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_ return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask): - video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA") + video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) visible= "A" in video_prompt_type return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible ) def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide): - video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV") + video_prompt_type = del_in_sequence(video_prompt_type, "PDSFCMUV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type @@ -5242,12 +5309,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Transfer Human Motion", "PV"), ("Transfer Depth", "DV"), ("Transfer Shapes", "SV"), + ("Transfer Flow", "FV"), ("Recolorize", "CV"), - ("Inpainting", "MV"), - ("Vace raw format", "V"), + ("Perform Inpainting", "MV"), + ("Use Vace raw format", "V"), ("Keep Unchanged", "UV"), - ], - value=filter_letters(video_prompt_type_value, "DSPCMUV"), + ("Transfer Human Motion & Depth", "PDV"), + ("Transfer Human Motion & Shape", "PSV"), + ("Transfer Human Motion & Flow", "PFV"), + ("Transfer Depth & Shape", "DSV"), + ("Transfer Depth & Flow", "DFV"), + ("Transfer Shapes & Flow", "SFV"), + ], + value=filter_letters(video_prompt_type_value, "PDSFCMUV"), label="Control Video Process", scale = 2, visible= True ) elif hunyuan_video_custom_edit: @@ -5256,7 +5330,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Inpaint Control Video", "MV"), ("Transfer Human Motion", "PMV"), ], - value=filter_letters(video_prompt_type_value, "DSPCMUV"), + value=filter_letters(video_prompt_type_value, "PDSFCMUV"), label="Video to Video", scale = 3, visible= True ) else: @@ -5286,8 +5360,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Non Masked Area, rest Depth", "YNA"), ("Masked Area, rest Shapes", "WA"), ("Non Masked Area, rest Shapes", "WNA"), + ("Masked Area, rest Flow", "ZA"), + ("Non Masked Area, rest Flow", "ZNA"), ], - value= filter_letters(video_prompt_type_value, "XYWNA"), + value= filter_letters(video_prompt_type_value, "XYZWNA"), visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom, label="Area Processed", scale = 2 ) @@ -5466,6 +5542,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v)) flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale") + with gr.Row(visible = vace): + control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) + control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) with gr.Row(): negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") ) with gr.Tab("Loras"): @@ -6144,6 +6223,15 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan interactive= not lock_ui_compile ) + depth_anything_v2_variant_choice = gr.Dropdown( + choices=[ + ("Large (more precise but 2x slower)", "vitl"), + ("Big (less precise, less VRAM needed but faster)", "vitb"), + ], + value= server_config.get("depth_anything_v2_variant", "vitl"), + label="Depth Anything v2 Vace Preprocessor Model type", + ) + vae_config_choice = gr.Dropdown( choices=[ ("Auto", 0), @@ -6224,6 +6312,7 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan enhancer_enabled_choice, fit_canvas_choice, preload_in_VRAM_choice, + depth_anything_v2_variant_choice, notification_sound_enabled_choice, notification_sound_volume_choice ],