mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 06:15:17 +00:00

Multi Vace controlnets and multithreaded preprocessing

This commit is contained in:
  parent a790dd8432
  commit 28fe3528e8

README.md | 10
@@ -20,6 +20,16 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 
 ## 🔥 Latest Updates
+### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldn't squeeze Vace even more?
+- Multithreaded preprocessing when possible, for faster generations
+- Multithreaded Lanczos upsampling of frames as a bonus
+- A new Vace preprocessor: *Flow*, to extract fluid motion
+- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore; for instance, if you transfer *Human Movement* and *Shapes* at the same time, the lighting of your character will, for some reason, take its environment into account much more.
+- Injected Frames Outpainting, in case you missed it in WanGP 6.21
+
+Don't know how to use all of the Vace features? Check the Vace Guide embedded in WanGP; it has also been updated.
+
+
 ### June 19 2025: WanGP v6.2, Vace even more Powercharged
 👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: 
 - If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time
@@ -20,9 +20,33 @@ def convert_to_numpy(image):
 class DepthV2Annotator:
     def __init__(self, cfg, device=None):
         from .dpt import DepthAnythingV2
+
+        # Model configurations for different variants
+        self.model_configs = {
+            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+        }
+
+        # Get model variant from config, default to 'vitl' if not specified
+        model_variant = cfg.get('MODEL_VARIANT', 'vitl')
+        if model_variant not in self.model_configs:
+            raise ValueError(f"Invalid model variant '{model_variant}'. Must be one of: {list(self.model_configs.keys())}")
+
         pretrained_model = cfg['PRETRAINED_MODEL']
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
-        self.model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(self.device)
+
+        # Get configuration for the selected model variant
+        config = self.model_configs[model_variant]
+
+        # Initialize model with the appropriate configuration
+        self.model = DepthAnythingV2(
+            encoder=config['encoder'],
+            features=config['features'],
+            out_channels=config['out_channels']
+        ).to(self.device)
+
         self.model.load_state_dict(
             torch.load(
                 pretrained_model,
@@ -54,4 +78,4 @@ class DepthV2VideoAnnotator(DepthV2Annotator):
         for frame in frames:
             anno_frame = super().forward(np.array(frame))
             ret_frames.append(anno_frame)
         return ret_frames
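For reference, a minimal usage sketch of the new variant selection. The config keys PRETRAINED_MODEL and MODEL_VARIANT come from the diff above; the checkpoint path and the dummy input are placeholders, not part of this commit.

import numpy as np

cfg = {
    'PRETRAINED_MODEL': 'ckpts/depth_anything_v2_vitl.pth',  # assumed path, adjust to your checkpoint
    'MODEL_VARIANT': 'vitl',  # one of 'vits', 'vitb', 'vitl', 'vitg'
}
annotator = DepthV2Annotator(cfg)  # encoder/features/out_channels are looked up in model_configs
depth_map = annotator.forward(np.zeros((480, 640, 3), dtype=np.uint8))  # dummy RGB frame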
@@ -2,13 +2,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os
-
 import cv2
 import torch
 import numpy as np
 from . import util
 from .wholebody import Wholebody, HWC3, resize_image
 from PIL import Image
+import onnxruntime as ort
+from concurrent.futures import ThreadPoolExecutor
+import threading
 
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 
@@ -23,8 +25,6 @@ def convert_to_numpy(image):
         raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
     return image
 
-
-
 def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
     bodies = pose['bodies']
     faces = pose['faces']
@@ -43,6 +43,56 @@ def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
     return canvas
 
 
+class OptimizedWholebody:
+    """Optimized version of Wholebody for faster serial processing"""
+    def __init__(self, onnx_det, onnx_pose, device='cuda:0'):
+        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
+        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
+        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
+        self.device = device
+
+        # Pre-allocate session options for better performance
+        self.session_det.set_providers(providers)
+        self.session_pose.set_providers(providers)
+
+        # Get input names once to avoid repeated lookups
+        self.det_input_name = self.session_det.get_inputs()[0].name
+        self.pose_input_name = self.session_pose.get_inputs()[0].name
+        self.pose_output_names = [out.name for out in self.session_pose.get_outputs()]
+
+    def __call__(self, ori_img):
+        from .onnxdet import inference_detector
+        from .onnxpose import inference_pose
+
+        det_result = inference_detector(self.session_det, ori_img)
+        keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
+
+        keypoints_info = np.concatenate(
+            (keypoints, scores[..., None]), axis=-1)
+        # compute neck joint
+        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+        # neck score when visualizing pred
+        neck[:, 2:4] = np.logical_and(
+            keypoints_info[:, 5, 2:4] > 0.3,
+            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
+        new_keypoints_info = np.insert(
+            keypoints_info, 17, neck, axis=1)
+        mmpose_idx = [
+            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
+        ]
+        openpose_idx = [
+            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
+        ]
+        new_keypoints_info[:, openpose_idx] = \
+            new_keypoints_info[:, mmpose_idx]
+        keypoints_info = new_keypoints_info
+
+        keypoints, scores = keypoints_info[
+            ..., :2], keypoints_info[..., 2]
+
+        return keypoints, scores, det_result
+
+
 class PoseAnnotator:
     def __init__(self, cfg, device=None):
         onnx_det = cfg['DETECTION_MODEL']
@@ -84,9 +134,7 @@ class PoseAnnotator:
             candidate[un_visible] = -1
 
             foot = candidate[:, 18:24]
-
             faces = candidate[:, 24:92]
-
             hands = candidate[:, 92:113]
             hands = np.vstack([hands, candidate[:, 113:]])
 
@@ -127,10 +175,24 @@ class PoseAnnotator:
             return ret_data, det_result
 
 
+class OptimizedPoseAnnotator(PoseAnnotator):
+    """Optimized version using improved Wholebody class"""
+    def __init__(self, cfg, device=None):
+        onnx_det = cfg['DETECTION_MODEL']
+        onnx_pose = cfg['POSE_MODEL']
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+        self.pose_estimation = OptimizedWholebody(onnx_det, onnx_pose, device=self.device)
+        self.resize_size = cfg.get("RESIZE_SIZE", 1024)
+        self.use_body = cfg.get('USE_BODY', True)
+        self.use_face = cfg.get('USE_FACE', True)
+        self.use_hand = cfg.get('USE_HAND', True)
+
+
 class PoseBodyFaceAnnotator(PoseAnnotator):
     def __init__(self, cfg):
         super().__init__(cfg)
         self.use_body, self.use_face, self.use_hand = True, True, False
+
     @torch.no_grad()
     @torch.inference_mode
     def forward(self, image):
@@ -138,14 +200,108 @@ class PoseBodyFaceAnnotator(PoseAnnotator):
         return ret_data['detected_map_bodyface']
 
 
-class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
+class OptimizedPoseBodyFaceVideoAnnotator:
+    """Optimized video annotator with multiple optimization strategies"""
+    def __init__(self, cfg, num_workers=5, chunk_size=8):
+        self.cfg = cfg
+        self.num_workers = num_workers
+        self.chunk_size = chunk_size
+        self.use_body, self.use_face, self.use_hand = True, True, False
+
+        # Initialize one annotator per worker to avoid ONNX session conflicts
+        self.annotators = []
+        for _ in range(num_workers):
+            annotator = OptimizedPoseAnnotator(cfg)
+            annotator.use_body, annotator.use_face, annotator.use_hand = True, True, False
+            self.annotators.append(annotator)
+
+        self._current_worker = 0
+        self._worker_lock = threading.Lock()
+
+    def _get_annotator(self):
+        """Get next available annotator in round-robin fashion"""
+        with self._worker_lock:
+            annotator = self.annotators[self._current_worker]
+            self._current_worker = (self._current_worker + 1) % len(self.annotators)
+            return annotator
+
+    def _process_single_frame(self, frame_data):
+        """Process a single frame with error handling"""
+        frame, frame_idx = frame_data
+        try:
+            annotator = self._get_annotator()
+
+            # Convert frame
+            frame = convert_to_numpy(frame)
+            input_image = HWC3(frame[..., ::-1])
+            resized_image = resize_image(input_image, annotator.resize_size)
+
+            # Process
+            ret_data, _ = annotator.process(resized_image, frame.shape[:2])
+
+            if 'detected_map_bodyface' in ret_data:
+                return frame_idx, ret_data['detected_map_bodyface']
+            else:
+                # Create empty frame if no detection
+                h, w = frame.shape[:2]
+                return frame_idx, np.zeros((h, w, 3), dtype=np.uint8)
+
+        except Exception as e:
+            print(f"Error processing frame {frame_idx}: {e}")
+            # Return empty frame on error
+            h, w = frame.shape[:2] if hasattr(frame, 'shape') else (480, 640)
+            return frame_idx, np.zeros((h, w, 3), dtype=np.uint8)
+
     def forward(self, frames):
-        ret_frames = []
-        for frame in frames:
-            anno_frame = super().forward(np.array(frame))
-            ret_frames.append(anno_frame)
-        return ret_frames
+        """Process video frames with optimizations"""
+        if len(frames) == 0:
+            return []
+
+        # For small number of frames, use serial processing to avoid threading overhead
+        if len(frames) <= 4:
+            annotator = self.annotators[0]
+            ret_frames = []
+            for frame in frames:
+                frame = convert_to_numpy(frame)
+                input_image = HWC3(frame[..., ::-1])
+                resized_image = resize_image(input_image, annotator.resize_size)
+                ret_data, _ = annotator.process(resized_image, frame.shape[:2])
+
+                if 'detected_map_bodyface' in ret_data:
+                    ret_frames.append(ret_data['detected_map_bodyface'])
+                else:
+                    h, w = frame.shape[:2]
+                    ret_frames.append(np.zeros((h, w, 3), dtype=np.uint8))
+            return ret_frames
+
+        # For larger videos, use parallel processing
+        frame_data = [(frame, idx) for idx, frame in enumerate(frames)]
+        results = [None] * len(frames)
+
+        # Process in chunks to manage memory
+        for chunk_start in range(0, len(frame_data), self.chunk_size * self.num_workers):
+            chunk_end = min(chunk_start + self.chunk_size * self.num_workers, len(frame_data))
+            chunk_data = frame_data[chunk_start:chunk_end]
+
+            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+                chunk_results = list(executor.map(self._process_single_frame, chunk_data))
+
+            # Store results in correct order
+            for frame_idx, result in chunk_results:
+                results[frame_idx] = result
+
+        return results
+
+
+# Alias for backward compatibility
+class PoseBodyFaceVideoAnnotator(OptimizedPoseBodyFaceVideoAnnotator):
+    """Backward compatible class name"""
+    def __init__(self, cfg, num_workers=2, chunk_size=8):
+        # Use optimized version with conservative settings
+        super().__init__(cfg, num_workers=num_workers, chunk_size=chunk_size)
+
+
+# Keep the existing utility functions
 import imageio
 
 def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
@@ -161,11 +317,7 @@ def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
     
 def get_frames(video_path):
     frames = []
-
-
-    # Opens the Video file with CV2
     cap = cv2.VideoCapture(video_path)
-
     fps = cap.get(cv2.CAP_PROP_FPS)
     print("video fps: " + str(fps))
     i = 0
@@ -175,9 +327,6 @@ def get_frames(video_path):
             break
         frames.append(frame)
         i += 1
-
     cap.release()
     cv2.destroyAllWindows()
     return frames, fps
-
-
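For reference, a minimal usage sketch of the multithreaded pose annotator added above. The config keys DETECTION_MODEL, POSE_MODEL and RESIZE_SIZE come from the code; the ONNX paths and the frames variable are placeholders.

cfg = {
    'DETECTION_MODEL': 'path/to/detector.onnx',  # placeholder
    'POSE_MODEL': 'path/to/pose.onnx',           # placeholder
    'RESIZE_SIZE': 1024,
}
annotator = PoseBodyFaceVideoAnnotator(cfg, num_workers=2, chunk_size=8)
pose_maps = annotator.forward(frames)  # frames: list of RGB arrays; 4 frames or fewer run serially, larger lists in chunks via ThreadPoolExecutor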
							
								
								
									
preprocessing/flow.py | 58 (new file)
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import numpy as np
import argparse
from PIL import Image

def convert_to_numpy(image):
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    elif isinstance(image, np.ndarray):
        image = image.copy()
    else:
        raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
    return image

class FlowAnnotator:
    def __init__(self, cfg, device=None):
        from .raft.raft import RAFT
        from .raft.utils.utils import InputPadder
        from .raft.utils import flow_viz

        params = {
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False
        }
        params = argparse.Namespace(**params)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = RAFT(params)
        self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()})
        self.model = self.model.to(self.device).eval()
        self.InputPadder = InputPadder
        self.flow_viz = flow_viz

    def forward(self, frames):
        # frames / RGB
        frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames]
        flow_up_list, flow_up_vis_list = [], []
        with torch.no_grad():
            for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])):
                padder = self.InputPadder(image1.shape)
                image1, image2 = padder.pad(image1, image2)
                flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True)
                flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy()
                flow_up_vis = self.flow_viz.flow_to_image(flow_up)
                flow_up_list.append(flow_up)
                flow_up_vis_list.append(flow_up_vis)
        return flow_up_list, flow_up_vis_list  # RGB


class FlowVisAnnotator(FlowAnnotator):
    def forward(self, frames):
        flow_up_list, flow_up_vis_list = super().forward(frames)
        return flow_up_vis_list[:1] + flow_up_vis_list
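For reference, a minimal usage sketch of the new Flow preprocessor. The cfg key PRETRAINED_MODEL is from the file above; the RAFT checkpoint path and the frames variable are placeholders.

cfg = {'PRETRAINED_MODEL': 'ckpts/raft-things.pth'}  # placeholder path to RAFT weights
annotator = FlowVisAnnotator(cfg)
flow_vis = annotator.forward(frames)  # frames: list of RGB images; returns one flow visualization per input frame
# (the first visualization is duplicated, since flow is computed between consecutive frame pairs)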
							
								
								
									
preprocessing/raft/__init__.py | 0 (new file)

preprocessing/raft/corr.py | 91 (new file)
@@ -0,0 +1,91 @@
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
except:
    # alt_cuda_corr is not compiled
    pass


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2*r+1)
            dy = torch.linspace(-r, r, 2*r+1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
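A shape sketch of the correlation volume built above, with illustrative sizes; it assumes the code runs inside the package so that bilinear_sampler from .utils.utils is available.

import torch
fmap1 = torch.randn(1, 256, 55, 128)   # (B, D, H/8, W/8) feature maps
fmap2 = torch.randn(1, 256, 55, 128)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)  # all-pairs volume of shape (1, 55, 128, 1, 55, 128), pooled over 4 levels
coords = torch.zeros(1, 2, 55, 128)     # dummy lookup coordinates (normally coords_grid output)
out = corr_fn(coords)                   # (1, 324, 55, 128): 4 levels * (2*4+1)**2 = 324 channels per pixel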
							
								
								
									
preprocessing/raft/datasets.py | 235 (new file)
@@ -0,0 +1,235 @@
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch

import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F

import os
import math
import random
from glob import glob
import os.path as osp

from .utils import frame_utils
from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):

        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)


class MpiSintel(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        if split == 'test':
            self.is_test = True

        for scene in os.listdir(image_root):
            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
            for i in range(len(image_list)-1):
                self.image_list += [ [image_list[i], image_list[i+1]] ]
                self.extra_info += [ (scene, i) ] # scene and frame_id

            if split != 'test':
                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))


class FlyingChairs(FlowDataset):
    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(flows))

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]


class FlyingThings3D(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])

                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')) )
                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            self.image_list += [ [images[i], images[i+1]] ]
                            self.flow_list += [ flows[i] ]
                        elif direction == 'into_past':
                            self.image_list += [ [images[i+1], images[i]] ]
                            self.flow_list += [ flows[i+1] ]


class KITTI(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        for img1, img2 in zip(images1, images2):
            frame_id = img1.split('/')[-1]
            self.extra_info += [ [frame_id] ]
            self.image_list += [ [img1, img2] ]

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))


class HD1K(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while 1:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            if len(flows) == 0:
                break

            for i in range(len(flows)-1):
                self.flow_list += [flows[i]]
                self.image_list += [ [images[i], images[i+1]] ]

            seq_ix += 1


def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding trainign set """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')
    
    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
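A minimal sketch of how fetch_dataloader() is driven; the argparse fields mirror what the function reads, and the corresponding dataset folders (plus chairs_split.txt) must exist locally.

import argparse
args = argparse.Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
loader = fetch_dataloader(args)  # shuffled DataLoader over FlyingChairs image pairs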
							
								
								
									
										267
									
								
								preprocessing/raft/extractor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										267
									
								
								preprocessing/raft/extractor.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,267 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ResidualBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
 | 
				
			||||||
 | 
					        super(ResidualBlock, self).__init__()
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
 | 
				
			||||||
 | 
					        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
 | 
				
			||||||
 | 
					        self.relu = nn.ReLU(inplace=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        num_groups = planes // 8
 | 
				
        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64,  stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


class SmallEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32,  stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
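For orientation only (not part of the commit): a minimal sketch of the encoders' I/O contract as consumed by raft.py below — features come out at 1/8 of the input resolution. The import path is an assumption based on where this commit adds the file.

# Hedged sketch, not part of the commit: shapes produced by BasicEncoder.
import torch
from preprocessing.raft.extractor import BasicEncoder  # assumed import path

fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0).eval()
img = torch.randn(1, 3, 256, 512)      # N, C, H, W
with torch.no_grad():
    feat = fnet(img)
print(feat.shape)                      # torch.Size([1, 256, 32, 64]) -> H/8 x W/8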
							
								
								
									
144 preprocessing/raft/raft.py Normal file
@@ -0,0 +1,144 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast('cuda', enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast('cuda', enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast('cuda', enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
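For orientation only (not part of the commit): a hedged sketch of how this RAFT module is typically driven. The checkpoint name is hypothetical, the import path is assumed from the added file location, and only the args actually read in __init__ and forward are set.

import argparse
import torch
from preprocessing.raft.raft import RAFT   # assumed import path

args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = RAFT(args).eval()
# state = torch.load('raft-things.pth', map_location='cpu')   # hypothetical checkpoint file
# model.load_state_dict({k.replace('module.', ''): v for k, v in state.items()})

frame1 = torch.randint(0, 256, (1, 3, 440, 1024)).float()     # pixel values in [0, 255]; H and W divisible by 8
frame2 = torch.randint(0, 256, (1, 3, 440, 1024)).float()
with torch.no_grad():
    flow_low, flow_up = model(frame1, frame2, iters=20, test_mode=True)
print(flow_up.shape)                                          # torch.Size([1, 2, 440, 1024])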
							
								
								
									
139 preprocessing/raft/update.py Normal file
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))

        h = (1-z) * h + z * q
        return h

class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))


    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h

class SmallMotionEncoder(nn.Module):
    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
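A hedged shape check (not part of the commit) for BasicUpdateBlock with the defaults raft.py sets in the non-small branch (corr_levels=4, corr_radius=4): the correlation input then has 4*(2*4+1)**2 = 324 channels and the predicted upsampling mask has 64*9 = 576. The import path is assumed.

import argparse
import torch
from preprocessing.raft.update import BasicUpdateBlock   # assumed import path

args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

N, H, W = 1, 32, 64                                  # 1/8-resolution feature grid
net  = torch.zeros(N, 128, H, W)                     # hidden state (tanh half of the context)
inp  = torch.zeros(N, 128, H, W)                     # context features (relu half)
corr = torch.zeros(N, 4 * (2 * 4 + 1) ** 2, H, W)    # 324 correlation channels
flow = torch.zeros(N, 2, H, W)

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)               # [1, 576, 32, 64] [1, 2, 32, 64]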
							
								
								
									
0 preprocessing/raft/utils/__init__.py Normal file

246 preprocessing/raft/utils/augmentor.py Normal file
@@ -0,0 +1,246 @@
import numpy as np
import random
import math
from PIL import Image

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F


class FlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow

class SparseFlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht), indexing='ij')
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
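A hedged usage sketch (not part of the commit) for the training-time FlowAugmentor above; the crop size and array sizes are arbitrary examples and the import path is assumed.

import numpy as np
from preprocessing.raft.utils.augmentor import FlowAugmentor   # assumed import path

aug = FlowAugmentor(crop_size=(368, 496))
img1 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
img2 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
flow = np.random.randn(480, 640, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)      # jitter colors, erase patches, rescale/flip/crop
print(img1.shape, flow.shape)                 # (368, 496, 3) (368, 496, 2)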
							
								
								
									
132 preprocessing/raft/utils/flow_viz.py Normal file
@@ -0,0 +1,132 @@
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization


# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx]  = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape [H,W,2].

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
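A hedged usage sketch (not part of the commit): turning a flow field into an RGB visualization with flow_to_image; the import path is assumed.

import numpy as np
from preprocessing.raft.utils.flow_viz import flow_to_image   # assumed import path

flow = np.zeros((240, 320, 2), dtype=np.float32)
flow[..., 0] = 4.0                     # constant rightward motion
rgb = flow_to_image(flow)              # uint8 array of shape (240, 320, 3)
print(rgb.dtype, rgb.shape)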
							
								
								
									
137 preprocessing/raft/utils/frame_utils.py Normal file
@@ -0,0 +1,137 @@
import numpy as np
from PIL import Image
from os.path import *
import re

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

TAG_CHAR = np.array([202021.25], np.float32)

def readFlow(fn):
    """ Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    # print 'fn = %s'%(fn)
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            # print 'Reading %d x %d flo file\n' % (w, h)
            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))

def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0: # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>' # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data

def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape
    f = open(filename,'wb')
    # write the header
    f.write(TAG_CHAR)
    np.array(width).astype(np.int32).tofile(f)
    np.array(height).astype(np.int32).tofile(f)
    # arrange into matrix form
    tmp = np.zeros((height, width*nBands))
    tmp[:,np.arange(width)*2] = u
    tmp[:,np.arange(width)*2 + 1] = v
    tmp.astype(np.float32).tofile(f)
    f.close()


def readFlowKITTI(filename):
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    flow = (flow - 2**15) / 64.0
    return flow, valid

def readDispKITTI(filename):
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid


def writeFlowKITTI(filename, uv):
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, uv[..., ::-1])


def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
        return Image.open(file_name)
    elif ext == '.bin' or ext == '.raw':
        return np.load(file_name)
    elif ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    elif ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []
82  preprocessing/raft/utils/utils.py  (Normal file)
@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self,x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]

def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht), indexing='ij')

    x1 = x0 + dx
    y1 = y0 + dy

    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
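As an aside (not part of the commit), InputPadder above is the usual RAFT helper: pad both frames to multiples of 8 before the forward pass, then crop the predicted flow back. A sketch with placeholder tensors standing in for real frames and the RAFT model:

import torch

frame1 = torch.rand(1, 3, 270, 480)          # height not divisible by 8
frame2 = torch.rand(1, 3, 270, 480)
padder = InputPadder(frame1.shape)           # class defined above
frame1, frame2 = padder.pad(frame1, frame2)  # both padded to 272 x 480
flow = torch.rand(1, 2, *frame1.shape[-2:])  # stand-in for the RAFT output
flow = padder.unpad(flow)                    # cropped back to 270 x 480
print(flow.shape)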
@@ -417,7 +417,7 @@ class WanAttentionBlock(nn.Module):
         freqs,
         context,
         hints= None, 
-        context_scale=1.0,
+        context_scale=[1.0],
         cam_emb= None,
         block_mask = None,
         audio_proj= None,
@@ -431,7 +431,7 @@ class WanAttentionBlock(nn.Module):
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
-        hint = None
+        hints_processed = None
         attention_dtype =  self.self_attn.q.weight.dtype 
         dtype = x.dtype
 
@@ -442,10 +442,13 @@ class WanAttentionBlock(nn.Module):
                 "context" : context,
                 "e" : e,
             }
-            if self.block_id == 0:
-                hint = self.vace(hints, x, **kwargs)
-            else:
-                hint = self.vace(hints, None, **kwargs)
+            hints_processed= []
+            for scale, hint in zip(context_scale, hints):
+                if scale == 0:
+                    hints_processed.append(None)
+                else:
+                    hints_processed.append(self.vace(hint, x, **kwargs) if self.block_id == 0 else self.vace(hint, None, **kwargs))
+
         latent_frames = e.shape[0]
         e = (self.modulation + e).chunk(6, dim=1)
         # self-attention
@@ -506,11 +509,13 @@ class WanAttentionBlock(nn.Module):
         x.addcmul_(y, e[5])
         x, y = reshape_latent(x , 1), reshape_latent(y , 1)
 
-        if hint is not None:
-            if context_scale == 1:
-                x.add_(hint)
-            else:
-                x.add_(hint, alpha= context_scale)
+        if hints_processed is not None:
+            for hint, scale in zip(hints_processed, context_scale):
+                if scale != 0:
+                    if scale == 1:
+                        x.add_(hint)
+                    else:
+                        x.add_(hint, alpha= scale)
         return x 


@@ -605,8 +610,8 @@ class MLPProj(torch.nn.Module):
 
 class WanModel(ModelMixin, ConfigMixin):
     def setup_chipmunk(self):
-        from chipmunk.util import LayerCounter
-        from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+        # from chipmunk.util import LayerCounter
+        # from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
         seq_shape = (21, 45, 80)
         chipmunk_layers =[]
         for i in range(self.num_layers):
@@ -941,7 +946,7 @@ class WanModel(ModelMixin, ConfigMixin):
         t,
         context,
         vace_context = None,
-        vace_context_scale=1.0,        
+        vace_context_scale=[1.0],        
         clip_fea=None,
         y=None,
         freqs = None,
@@ -972,7 +977,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         chipmunk = offload.shared_state.get("_chipmunk", False) 
         if chipmunk:
-            from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
+            # from chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
             voxel_shape = (4, 6, 8)
 
         x_list = x
@@ -1065,10 +1070,10 @@ class WanModel(ModelMixin, ConfigMixin):
             # Vace embeddings
             c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
             c = [u.flatten(2).transpose(1, 2) for u in c]
-            c = c[0]
-
+            # c = c[0]
+            c = [ [sub_c] for sub_c in c]
             kwargs['context_scale'] = vace_context_scale
-            hints_list = [ [c] for _ in range(len(x_list)) ] 
+            hints_list = [ c ]* len(x_list) 
            del c
 
         should_calc = True
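The net effect of the attention-block change above: instead of one hint scaled by a single context_scale, each Vace controlnet now contributes its own hint weighted by its own scale, and a zero scale skips that controlnet entirely. A toy illustration of the accumulation (not part of the commit, shapes are arbitrary):

import torch

x = torch.zeros(2, 4)                                    # stand-in for the latent
hints_processed = [torch.ones(2, 4), 2 * torch.ones(2, 4), None]
context_scale = [1.0, 0.5, 0.0]                          # one weight per controlnet, 0 disables it
for hint, scale in zip(hints_processed, context_scale):
    if hint is None or scale == 0:
        continue
    if scale == 1:
        x.add_(hint)
    else:
        x.add_(hint, alpha=scale)
print(x)  # every element is 1*1.0 + 2*0.5 = 2.0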
@@ -132,7 +132,7 @@ class WanT2V:
             inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
             reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
             inactive = self.vae.encode(inactive, tile_size = tile_size)
-            self.toto = inactive[0].clone() 
+
             if overlapped_latents  != None  : 
                 # inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
                 inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
@@ -220,7 +220,7 @@ class WanT2V:
             ref_img = white_canvas
         return ref_img.to(device)
 
-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, original_video = False, keep_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, keep_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
         image_sizes = []
         trim_video = len(keep_frames)
         def conv_tensor(t, device):
@@ -254,7 +254,7 @@ class WanT2V:
                 image_sizes.append(image_size)
             else:
                 src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
-                src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
+                src_mask[i] = torch.ones_like(src_video[i], device=device)
                 if prepend_count > 0:
                     src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                     src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
@@ -278,15 +278,15 @@ class WanT2V:
             if ref_images is not None:
                 image_size = image_sizes[i]
                 for j, ref_img in enumerate(ref_images):
-                    if ref_img is not None:
+                    if ref_img is not None and not torch.is_tensor(ref_img):
                         src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
         return src_video, src_mask, src_ref_images
 
     def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
         if ref_images is None:
             ref_images = [None] * len(zs)
-        else:
-            assert len(zs) == len(ref_images)
+        # else:
+        #     assert len(zs) == len(ref_images)
 
         trimed_zs = []
         for z, refs in zip(zs, ref_images):
@@ -312,7 +312,7 @@ class WanT2V:
                 input_ref_images = None,      
                 input_video=None,
                 target_camera=None,                  
-                context_scale=1.0,
+                context_scale=None,
                 width = 1280,
                 height = 720,
                 fit_into_canvas = True,
@@ -477,7 +477,8 @@ class WanT2V:
             kwargs.update({'cam_emb': cam_emb})
 
         if vace:
-            ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 
+            ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
+            context_scale = context_scale if context_scale != None else [1.0] * len(z)
             kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
             if overlapped_latents != None :
                 overlapped_latents_size = overlapped_latents.shape[1] + 1
279  wgp.py
@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.4.9"
-WanGP_version = "6.21"
+WanGP_version = "6.3"
 settings_version = 2
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
 
@@ -429,8 +429,8 @@ def process_prompt_and_add_tasks(state, model_choice):
     return update_queue_data(queue)
 
 def get_preview_images(inputs):
-    inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"]
-    labels = ["Start Image", "End Image", "Video Guide", "Image Reference","Video Mask", "Video Source"]
+    inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "video_mask", "image_refs" ]
+    labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Video Mask", "Image Reference"]
     start_image_data = None
     start_image_labels = []
     end_image_data = None
@@ -1768,7 +1768,7 @@ def fix_settings(model_type, ui_defaults):
         return
     
     video_prompt_type = ui_defaults.get("video_prompt_type", "")
-    if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom", "phantom_1.3B"]:
+    if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"]:
         if not "I" in video_prompt_type:  # workaround for settings corruption
             video_prompt_type += "I" 
     if model_type in ["hunyuan"]:
@@ -2081,8 +2081,8 @@ def download_models(model_filename, model_type):
 
     shared_def = {
         "repoId" : "DeepBeepMeep/Wan2.1",
-        "sourceFolderList" : [ "pose", "scribble", "depth", "mask", "wav2vec", ""  ],
-        "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["depth_anything_v2_vitl.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
+        "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", ""  ],
+        "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
                 [ "flownet.pkl"  ] ]
     }
     process_files_def(**shared_def)
@@ -2512,6 +2512,7 @@ def apply_changes(  state,
                     enhancer_enabled_choice = 0,
                     fit_canvas_choice = 0,
                     preload_in_VRAM_choice = 0,
+                    depth_anything_v2_variant_choice = "vitl",
                     notification_sound_enabled_choice = 1,
                     notification_sound_volume_choice = 50
 ):
@@ -2540,6 +2541,7 @@ def apply_changes(  state,
         "fit_canvas": fit_canvas_choice,
         "enhancer_enabled" : enhancer_enabled_choice,
         "preload_in_VRAM" : preload_in_VRAM_choice,
+        "depth_anything_v2_variant": depth_anything_v2_variant_choice,
         "notification_sound_enabled" : notification_sound_enabled_choice,
         "notification_sound_volume" : notification_sound_volume_choice
     }
@@ -2577,7 +2579,7 @@ def apply_changes(  state,
     transformer_types = server_config["transformer_types"]
     model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
     state["model_filename"] = model_filename
-    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "notification_sound_enabled", "notification_sound_volume"] for change in changes ):
+    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume"] for change in changes ):
         model_choice = gr.Dropdown()
     else:
         reload_needed = True
@@ -2818,7 +2820,7 @@ def get_preprocessor(process_type, inpaint_color):
             "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
             "RESIZE_SIZE": 1024
         }
-        anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)[0]
+        anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)
     elif process_type=="depth":
         # from preprocessing.midas.depth import DepthVideoAnnotator
         # cfg_dict = {
@@ -2827,30 +2829,72 @@ def get_preprocessor(process_type, inpaint_color):
         # anno_ins = lambda img: DepthVideoAnnotator(cfg_dict).forward(img)[0]
 
         from preprocessing.depth_anything_v2.depth import DepthV2VideoAnnotator
-        cfg_dict = {
-            "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth"
-            # "PRETRAINED_MODEL": "ckpts/depth/depth_anything_vitb14.pth"
-        }
-        anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)[0]
+
+        if server_config.get("depth_anything_v2_variant", "vitl") == "vitl":
+            cfg_dict = {
+                "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth",
+                'MODEL_VARIANT': 'vitl'
+            }
+        else:
+            cfg_dict = {
+                "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitb.pth",
+                'MODEL_VARIANT': 'vitb',
+            }
+
+        anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)
     elif process_type=="gray":
         from preprocessing.gray import GrayVideoAnnotator
         cfg_dict = {}
-        anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)[0]
+        anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)
     elif process_type=="scribble":
         from preprocessing.scribble import ScribbleVideoAnnotator
         cfg_dict = {
                 "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth"
             }
-        anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)[0]
+        anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)
+    elif process_type=="flow":
+        from preprocessing.flow import FlowVisAnnotator
+        cfg_dict = {
+                "PRETRAINED_MODEL": "ckpts/flow/raft-things.pth"
+            }
+        anno_ins = lambda img: FlowVisAnnotator(cfg_dict).forward(img)
     elif process_type=="inpaint":
-        anno_ins = lambda img : inpaint_color
-        # anno_ins = lambda img : np.full_like(img, inpaint_color)
+        anno_ins = lambda img :  len(img) * [inpaint_color]
+    elif process_type == None or process_type in ["vace", "identity"]:
+        anno_ins = lambda img : img
     else:
-        anno_ins = lambda img : img[0]
+        raise Exception(f"process type '{process_type}' non supported")
     return anno_ins
 
-def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,  max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None):
+
+def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) :
+    if not items:
+       return []    
+    
+    import concurrent.futures
+    start_time = time.time()
+    print(f"Preprocessus:{process_type} started")
+    if process_type in ["prephase", "upsample"]: 
+        if wrap_in_list :
+            items = [ [img] for img in items]
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
+            results = [None] * len(items)
+            for future in concurrent.futures.as_completed(futures):
+                idx = futures[future]
+                results[idx] = future.result()
+
+        if wrap_in_list: 
+            results = [ img[0] for img in results]
+    else:
+        results=  image_processor(items) 
+
+    end_time = time.time()
+    print(f"duration:{end_time-start_time:.1f}")
+
+    return results  
+
+def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,  max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
     from wan.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
 
     def mask_to_xyxy_box(mask):
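The new process_images_multithread above fans per-frame work out to a thread pool and reassembles the results in their original order. A self-contained sketch of the same pattern (not part of the commit), with a toy frame processor standing in for a real preprocessor:

import concurrent.futures
import os
import time

def process_frames_threaded(frame_processor, frames, max_workers=max(1, (os.cpu_count() or 2) // 2)):
    # submit one task per frame, then restore the original ordering
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(frame_processor, frame): idx for idx, frame in enumerate(frames)}
        results = [None] * len(frames)
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
    return results

def slow_double(x):
    time.sleep(0.01)   # simulate per-frame preprocessing cost
    return 2 * x

print(process_frames_threaded(slow_double, list(range(8))))   # [0, 2, 4, 6, 8, 10, 12, 14]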
@ -2871,14 +2915,21 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
 | 
				
			|||||||
        return None, None
 | 
					        return None, None
 | 
				
			||||||
    any_mask = input_mask_path != None
 | 
					    any_mask = input_mask_path != None
 | 
				
			||||||
    pose_special = "pose" in process_type
 | 
					    pose_special = "pose" in process_type
 | 
				
			||||||
    if process_type == "pose_depth": process_type = "pose"
 | 
					 | 
				
			||||||
    any_identity_mask = False
 | 
					    any_identity_mask = False
 | 
				
			||||||
    if process_type == "identity":
 | 
					    if process_type == "identity":
 | 
				
			||||||
        any_identity_mask = True
 | 
					        any_identity_mask = True
 | 
				
			||||||
        negate_mask = False
 | 
					        negate_mask = False
 | 
				
			||||||
        process_outside_mask = None
 | 
					        process_outside_mask = None
 | 
				
			||||||
    preproc = get_preprocessor(process_type, inpaint_color)
 | 
					    preproc = get_preprocessor(process_type, inpaint_color)
 | 
				
			||||||
    preproc2 = get_preprocessor(process_outside_mask, inpaint_color) if process_type != process_outside_mask else preproc
 | 
					    preproc2 = None
 | 
				
			||||||
 | 
					    if process_type2 != None:
 | 
				
			||||||
 | 
					        preproc2 = get_preprocessor(process_type2, inpaint_color) if process_type != process_type2 else preproc
 | 
				
			||||||
 | 
					    if process_outside_mask == process_type :
 | 
				
			||||||
 | 
					        preproc_outside = preproc
 | 
				
			||||||
 | 
					    elif preproc2 != None and process_outside_mask == process_type2 :
 | 
				
			||||||
 | 
					        preproc_outside = preproc2
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        preproc_outside = get_preprocessor(process_outside_mask, inpaint_color)
 | 
				
			||||||
    video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
 | 
					    video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
 | 
				
			||||||
    if any_mask:
 | 
					    if any_mask:
 | 
				
			||||||
        mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
 | 
					        mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
 | 
				
			||||||
@ -2905,12 +2956,16 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
 | 
				
			|||||||
        num_frames = min(len(video), len(mask_video))
 | 
					        num_frames = min(len(video), len(mask_video))
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        num_frames = len(video)
 | 
					        num_frames = len(video)
 | 
				
			||||||
    masked_frames = []
 | 
					
 | 
				
			||||||
    masks = []
 | 
					 | 
				
			||||||
    if any_identity_mask:
 | 
					    if any_identity_mask:
 | 
				
			||||||
        any_mask = True
 | 
					        any_mask = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for frame_idx in range(num_frames):
 | 
					    proc_list =[]
 | 
				
			||||||
 | 
					    proc_list_outside =[]
 | 
				
			||||||
 | 
					    proc_mask = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # for frame_idx in range(num_frames):
 | 
				
			||||||
 | 
					    def prep_prephase(frame_idx):
 | 
				
			||||||
        frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy()
 | 
					        frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy()
 | 
				
			||||||
        frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) 
 | 
					        frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) 
 | 
				
			||||||
        frame = np.array(frame) 
 | 
					        frame = np.array(frame) 
 | 
				
			||||||
@ -2946,17 +3001,33 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            target_frame = frame 
 | 
					            target_frame = frame 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        processed_img = preproc([target_frame])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if any_mask:
 | 
					        if any_mask:
 | 
				
			||||||
            if preproc2 != None:
 | 
					            return (target_frame, frame, mask) 
 | 
				
			||||||
                frame = preproc2([frame])
 | 
					 | 
				
			||||||
            masked_frame = np.where(mask[..., None], processed_img, frame)
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            masked_frame = processed_img
 | 
					            return (target_frame, None, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            
 | 
					    proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
 | 
				
			||||||
 | 
					    proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
 | 
				
			||||||
 | 
					    for frame_idx, frame_group in enumerate(proc_lists): 
 | 
				
			||||||
 | 
					        proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
 | 
				
			||||||
 | 
					    prep_prephase = None
 | 
				
			||||||
 | 
					    video = None
 | 
				
			||||||
 | 
					    mask_video = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if preproc2 != None:
 | 
				
			||||||
 | 
					        proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
 | 
				
			||||||
 | 
					        #### to be finished ...or not
 | 
				
			||||||
 | 
					    proc_list = process_images_multithread(preproc, proc_list, process_type)
 | 
				
			||||||
 | 
					    if any_mask:
 | 
				
			||||||
 | 
					        proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        proc_list_outside = proc_mask = len(proc_list) * [None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    masked_frames = []
 | 
				
			||||||
 | 
					    masks = []
 | 
				
			||||||
 | 
					    for frame_no, (processed_img, processed_img_outside, mask) in enumerate(zip(proc_list, proc_list_outside, proc_mask)):
 | 
				
			||||||
        if any_mask :
 | 
					        if any_mask :
 | 
				
			||||||
 | 
					            masked_frame = np.where(mask[..., None], processed_img, processed_img_outside)
 | 
				
			||||||
            if process_outside_mask != None:
 | 
					            if process_outside_mask != None:
 | 
				
			||||||
                mask = np.full_like(mask, 255)
 | 
					                mask = np.full_like(mask, 255)
 | 
				
			||||||
            mask = torch.from_numpy(mask)
 | 
					            mask = torch.from_numpy(mask)
 | 
				
			||||||
@ -2967,6 +3038,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
 | 
				
			|||||||
                full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
 | 
					                full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
 | 
				
			||||||
                mask = full_frame 
 | 
					                mask = full_frame 
 | 
				
			||||||
            masks.append(mask)
 | 
					            masks.append(mask)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            masked_frame = processed_img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(masked_frame, int):
 | 
					        if isinstance(masked_frame, int):
 | 
				
			||||||
            masked_frame= np.full( (height, width, 3), inpaint_color, dtype= np.uint8)
 | 
					            masked_frame= np.full( (height, width, 3), inpaint_color, dtype= np.uint8)
 | 
				
			||||||
@ -2981,15 +3054,18 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
 | 
				
			|||||||
            masked_frame = full_frame 
 | 
					            masked_frame = full_frame 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        masked_frames.append(masked_frame)
 | 
					        masked_frames.append(masked_frame)
 | 
				
			||||||
 | 
					        proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.save_masks:
 | 
					    if args.save_masks:
 | 
				
			||||||
        from preprocessing.dwpose.pose import save_one_video
 | 
					        from preprocessing.dwpose.pose import save_one_video
 | 
				
			||||||
        saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
 | 
					        saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
 | 
				
			||||||
        save_one_video("masked_frames.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
 | 
					        save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
 | 
				
			||||||
        if any_mask:
 | 
					        if any_mask:
 | 
				
			||||||
            saved_masks = [mask.cpu().numpy() for mask in masks ]
 | 
					            saved_masks = [mask.cpu().numpy() for mask in masks ]
 | 
				
			||||||
            save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
 | 
					            save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
 | 
				
			||||||
    preproc = None
 | 
					    preproc = None
 | 
				
			||||||
    preproc2 = None
 | 
					    preproc_outside = None
 | 
				
			||||||
    gc.collect()
 | 
					    gc.collect()
 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -3066,11 +3142,11 @@ def parse_keep_frames_video_guide(keep_frames, video_length):
 | 
				
			|||||||
            for i in range(start_range, end_range + 1):
 | 
					            for i in range(start_range, end_range + 1):
 | 
				
			||||||
                frames[i] = True
 | 
					                frames[i] = True
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if not is_integer(section):
 | 
					            if not is_integer(section) or int(section) == 0:
 | 
				
			||||||
                error =f"Invalid integer {section}"
 | 
					                error =f"Invalid integer {section}"
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
            index = absolute(int(section))
 | 
					            index = absolute(int(section))
 | 
				
			||||||
            frames[index] = True
 | 
					            frames[index-1] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if len(error ) > 0:
 | 
					    if len(error ) > 0:
 | 
				
			||||||
        return [], error
 | 
					        return [], error
 | 
				
			||||||
@@ -3113,6 +3189,8 @@ def generate_video(
     keep_frames_video_guide,
     video_guide_outpainting,
     video_mask,
+    control_net_weight,
+    control_net_weight2,
     mask_expand,
     audio_guide,
     sliding_window_size,
@@ -3397,6 +3475,7 @@ def generate_video(
         frames_already_processed = None
         pre_video_guide = None
         overlapped_latents = None
+        context_scale = None
         window_no = 0
         extra_windows = 0
         guide_start_frame = 0
@@ -3496,60 +3575,48 @@ def generate_video(
                 if len(error) > 0:
                     raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
                 keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length]
+                context_scale = [ control_net_weight]
                 if "V" in video_prompt_type:
-                    extra_label = ""
-                    if "X" in video_prompt_type:
-                        process_outside_mask = "inpaint"
-                    elif "Y" in video_prompt_type:
-                        process_outside_mask = "depth"
-                        extra_label = " and Depth"
-                    elif "W" in video_prompt_type:
-                        process_outside_mask = "scribble"
-                        extra_label = " and Shapes"
-                    else:
-                        process_outside_mask = None
-                    preprocess_type = None
-                    if "P" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,f"Extracting Open Pose{extra_label} Information")]
-                        preprocess_type = "pose"
-                    elif "D" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,"Extracting Depth Information")]
-                        preprocess_type = "depth"
-                    elif "S" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,"Extracting Shapes Information")]
-                        preprocess_type = "scribble"
-                    elif "C" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,f"Extracting Gray Level{extra_label} Information")]
-                        preprocess_type = "gray"
-                    elif "M" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,f"Creating Inpainting{extra_label} Mask")]
-                        preprocess_type = "inpaint"
-                    elif "U" in video_prompt_type :
-                        progress_args = [0, get_latest_status(state,f"Creating Identity{extra_label} Mask")]
-                        preprocess_type = "identity"
-                    else:
-                        progress_args = [0, get_latest_status(state,f"Creating Vace Generic{extra_label} Mask")]
-                        preprocess_type = "vace"
-                    send_cmd("progress", progress_args)
-                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims )
+                    process_map = { "Y" : "depth", "W": "scribble", "X": "inpaint", "Z": "flow"}
+                    process_outside_mask = process_map.get(filter_letters(video_prompt_type, "YWX"), None)
+                    preprocess_type2 = preprocess_type = None
+                    process_map = { "D" : "depth", "P": "pose", "S": "scribble", "F": "flow", "C": "gray", "M": "inpaint", "U": "identity"}
+                    for process_num, process_letter in  enumerate( filter_letters(video_prompt_type, "PDSFCMU")):
+                        if process_num == 0:
+                            preprocess_type = process_map.get(process_letter, "vace")
+                        else:
+                            preprocess_type2 = process_map.get(process_letter, None)
+                    process_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "U": "Identity Mask", "vace" : "Vace Data"}
+                    status_info = "Extracting " + process_names[preprocess_type]
+                    extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
+                    if len(extra_process_list) == 1:
+                        status_info += " and " + process_names[extra_process_list[0]]
+                    elif len(extra_process_list) == 2:
+                        status_info +=  ", " + process_names[extra_process_list[0]] + " and " + process_names[extra_process_list[1]]
+                    send_cmd("progress", [0, get_latest_status(state, status_info)])
+                    video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 )
+                    video_guide_copy2 = video_mask_copy2 = None
+                    if preprocess_type2 != None:
+                        video_guide_copy2, video_mask_copy2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) if guide_start_frame == 0 else len(keep_frames_parsed) - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+
                     if video_guide_copy != None:
                         if sample_fit_canvas != None:
                             image_size = video_guide_copy.shape[-3: -1]
                             sample_fit_canvas = None
                         refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy())
+                        if video_guide_copy2 != None:
+                            refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_copy2[0].cpu().numpy())]
                         if video_mask_copy != None:
                             refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy())
                 frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length]
 
-                src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
-                                                                        [video_mask_copy],
-                                                                        [image_refs_copy], 
+                src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy] if video_guide_copy2 == None else [video_guide_copy, video_guide_copy2],
+                                                                        [video_mask_copy] if video_guide_copy2 == None else [video_mask_copy, video_mask_copy2],
+                                                                        [image_refs_copy] if video_guide_copy2 == None else [image_refs_copy, image_refs_copy],
                                                                         current_video_length, image_size = image_size, device ="cpu",
-                                                                        original_video= "O" in video_prompt_type,
                                                                         keep_frames=keep_frames_parsed,
                                                                         start_frame = guide_start_frame,
-                                                                        pre_src_video = [pre_video_guide],
+                                                                        pre_src_video = [pre_video_guide] if video_guide_copy2 == None else [pre_video_guide, pre_video_guide],
                                                                         fit_into_canvas = sample_fit_canvas,
                                                                         inject_frames= frames_to_inject_parsed,
                                                                         outpainting_dims = outpainting_dims,
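
Editor's note: in the hunk above, the first letter found in the letter-coded prompt type selects the main preprocessing pass and the second letter (if any) feeds the second Vace controlnet. A rough sketch of that selection logic under those assumptions; `filter_letters` is reimplemented here purely for illustration and may not match the application's helper exactly.

    PROCESS_MAP = {"D": "depth", "P": "pose", "S": "scribble", "F": "flow",
                   "C": "gray", "M": "inpaint", "U": "identity"}

    def filter_letters(source: str, letters: str) -> str:
        # Keep only the characters of `source` that belong to `letters`, preserving order.
        return "".join(ch for ch in source if ch in letters)

    def select_preprocess_types(video_prompt_type: str):
        # First matching letter -> main controlnet input; second letter (if any) -> second Vace controlnet.
        preprocess_type, preprocess_type2 = "vace", None   # "vace" as the fallback raw format
        for num, letter in enumerate(filter_letters(video_prompt_type, "PDSFCMU")):
            if num == 0:
                preprocess_type = PROCESS_MAP.get(letter, "vace")
            else:
                preprocess_type2 = PROCESS_MAP.get(letter)
        return preprocess_type, preprocess_type2

    # e.g. "PDV" ("Transfer Human Motion & Depth") -> ("pose", "depth")
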
@@ -3567,7 +3634,7 @@ def generate_video(
 
                 send_cmd("progress", progress_args)
                 src_video, src_mask = preprocess_video_with_mask(video_guide,  video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
                 refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
                 if src_mask != None:
                     refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
             if len(refresh_preview) > 0:
@@ -3632,6 +3699,7 @@ def generate_video(
                     audio_proj= audio_proj_split,
                     audio_scale= audio_scale,
                     audio_context_lens= audio_context_lens,
+                    context_scale = context_scale,
                     ar_step = model_mode, #5
                     causal_block_size = 5,
                     causal_attention = True,
@@ -3768,15 +3836,12 @@ def generate_video(
                     w *= scale
                     h = int(h)
                     w = int(w)
-                    new_frames =[]
-                    for i in range( sample.shape[1] ):
-                        frame = sample[:, i]
-                        frame = resize_lanczos(frame, h, w)
-                        frame = frame.unsqueeze(1)
-                        new_frames.append(frame)
-                    sample = torch.cat(new_frames, dim=1)
-                    new_frames = None
-                    sample = sample * 2 - 1
+                    frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
+                    def upsample_frames(frame):
+                        return resize_lanczos(frame, h, w).unsqueeze(1)
+                    sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
+                    frames_to_upsample = None
+                    sample.mul_(2).sub_(1)
 
                 if sliding_window :
                     if frames_already_processed == None:
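
Editor's note: `process_images_multithread` is a WanGP helper not shown in this diff. A minimal thread-pool equivalent of the removed per-frame loop, using the same (C, T, H, W) tensor layout, is sketched below; `resize_lanczos_stub` merely stands in for the real `resize_lanczos`, and the worker count is an assumption. The in-place `mul_(2).sub_(1)` in the hunk presumably maps frames from [0, 1] back to the [-1, 1] range used downstream.

    from concurrent.futures import ThreadPoolExecutor

    import torch
    import torch.nn.functional as F

    def resize_lanczos_stub(frame: torch.Tensor, h: int, w: int) -> torch.Tensor:
        # Stand-in for WanGP's resize_lanczos: a simple bicubic resize of one (C, H, W) frame.
        return F.interpolate(frame.unsqueeze(0), size=(h, w), mode="bicubic",
                             align_corners=False).squeeze(0)

    def upsample_video_multithreaded(sample: torch.Tensor, h: int, w: int, workers: int = 4) -> torch.Tensor:
        # sample: (C, T, H, W). Each frame is resized in its own thread, then re-stacked on dim 1.
        frames = [sample[:, i] for i in range(sample.shape[1])]
        with ThreadPoolExecutor(max_workers=workers) as pool:
            resized = list(pool.map(lambda f: resize_lanczos_stub(f, h, w).unsqueeze(1), frames))
        return torch.cat(resized, dim=1)

Threading helps here because the resize work runs inside tensor operations that release the GIL, so frames can be processed in parallel without multiprocessing overhead.
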
@@ -4764,6 +4829,8 @@ def save_inputs(
             video_guide,
             keep_frames_video_guide,
             video_mask,
+            control_net_weight,
+            control_net_weight2,
             mask_expand,
             audio_guide,
             sliding_window_size,
@@ -4955,13 +5022,13 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
     return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
 
 def refresh_video_prompt_type_video_mask(video_prompt_type, video_prompt_type_video_mask):
-    video_prompt_type = del_in_sequence(video_prompt_type, "XWYNA")
+    video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
     visible= "A" in video_prompt_type
     return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible )
 
 def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
-    video_prompt_type = del_in_sequence(video_prompt_type, "DSPCMUV")
+    video_prompt_type = del_in_sequence(video_prompt_type, "PDSFCMUV")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     visible = "V" in video_prompt_type
     mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
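
Editor's note: the refresh handlers above edit single-letter codes inside `video_prompt_type`. Simplified stand-ins for the `del_in_sequence` / `add_to_sequence` helpers they call (the real implementations live elsewhere in the application) behave roughly like this:

    def del_in_sequence(sequence: str, letters: str) -> str:
        # Remove every occurrence of the given letters from the code string.
        return "".join(ch for ch in sequence if ch not in letters)

    def add_to_sequence(sequence: str, letters: str) -> str:
        # Append each letter not already present, keeping the code string duplicate-free.
        return sequence + "".join(ch for ch in letters if ch not in sequence)

    # Switching the mask mode: clear the old area codes ("XYZWNA"), then add the new selection.
    video_prompt_type = add_to_sequence(del_in_sequence("PVXA", "XYZWNA"), "ZNA")  # -> "PVZNA"
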
@@ -5242,12 +5309,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("Transfer Human Motion", "PV"),
                                 ("Transfer Depth", "DV"),
                                 ("Transfer Shapes", "SV"),
+                                ("Transfer Flow", "FV"),
                                 ("Recolorize", "CV"),
-                                ("Inpainting", "MV"),
-                                ("Vace raw format", "V"),
+                                ("Perform Inpainting", "MV"),
+                                ("Use Vace raw format", "V"),
                                 ("Keep Unchanged", "UV"),
-                            ],
-                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
+                                ("Transfer Human Motion & Depth", "PDV"),
+                                ("Transfer Human Motion & Shape", "PSV"),
+                                ("Transfer Human Motion & Flow", "PFV"),
+                                ("Transfer Depth & Shape", "DSV"),
+                                ("Transfer Depth & Flow", "DFV"),
+                                ("Transfer Shapes & Flow", "SFV"),
+                           ],
+                            value=filter_letters(video_prompt_type_value, "PDSFCMUV"),
                             label="Control Video Process", scale = 2, visible= True
                         )
                     elif hunyuan_video_custom_edit:
@@ -5256,7 +5330,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("Inpaint Control Video", "MV"),
                                 ("Transfer Human Motion", "PMV"),
                             ],
-                            value=filter_letters(video_prompt_type_value, "DSPCMUV"),
+                            value=filter_letters(video_prompt_type_value, "PDSFCMUV"),
                             label="Video to Video", scale = 3, visible= True
                         )
                     else:
@@ -5286,8 +5360,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                                 ("Non Masked Area, rest Depth", "YNA"),
                                 ("Masked Area, rest Shapes", "WA"),
                                 ("Non Masked Area, rest Shapes", "WNA"),
+                                ("Masked Area, rest Flow", "ZA"),
+                                ("Non Masked Area, rest Flow", "ZNA"),
                             ],
-                            value= filter_letters(video_prompt_type_value, "XYWNA"),
+                            value= filter_letters(video_prompt_type_value, "XYZWNA"),
                             visible=  "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom,
                             label="Area Processed", scale = 2
                         )
@@ -5466,6 +5542,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                             audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
                             embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v))
                             flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
+                        with gr.Row(visible = vace):
+                            control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace)
+                            control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace)
                         with gr.Row():
                             negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
                 with gr.Tab("Loras"):
@@ -6144,6 +6223,15 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan
                     interactive= not lock_ui_compile
                 )
 
+                depth_anything_v2_variant_choice = gr.Dropdown(
+                    choices=[
+                        ("Large (more precise but 2x slower)", "vitl"),
+                        ("Big (less precise, less VRAM needed but faster)", "vitb"),
+                    ],
+                    value= server_config.get("depth_anything_v2_variant", "vitl"),
+                    label="Depth Anything v2 Vace Preprocessor Model type",
+                )
+
                 vae_config_choice = gr.Dropdown(
                     choices=[
                 ("Auto", 0),
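
Editor's note: the new `depth_anything_v2_variant` setting is presumably forwarded to the Depth Anything V2 preprocessor when it is instantiated. A hedged sketch of that kind of wiring follows; the config map and function name are illustrative assumptions, not the repository's exact API.

    # Illustrative only: map the saved variant key to an encoder choice for the depth preprocessor.
    DEPTH_ANYTHING_V2_VARIANTS = {
        "vitl": {"encoder": "vitl"},   # larger checkpoint: more precise, roughly 2x slower
        "vitb": {"encoder": "vitb"},   # smaller checkpoint: faster, needs less VRAM
    }

    def pick_depth_variant(server_config: dict) -> dict:
        # Falls back to the large model, matching the dropdown's default above.
        variant = server_config.get("depth_anything_v2_variant", "vitl")
        return DEPTH_ANYTHING_V2_VARIANTS[variant]
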
@@ -6224,6 +6312,7 @@ def generate_configuration_tab(state, blocks, header, model_choice, prompt_enhan
                     enhancer_enabled_choice,
                     fit_canvas_choice,
                     preload_in_VRAM_choice,
+                    depth_anything_v2_variant_choice,
                     notification_sound_enabled_choice,
                     notification_sound_volume_choice
                 ],