import argparse
import json
import math
import os
import os.path as osp
import random
import subprocess
import tempfile
from datetime import datetime
from functools import lru_cache

import cv2
import decord
import ffmpeg
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from rembg import new_session, remove

os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")


def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame):
    """Return the source-video frame indices that remap the clip onto a
    target_fps timeline, starting at target frame start_target_frame."""
    video_frame_duration = 1 / video_fps
    target_frame_duration = 1 / target_fps

    target_time = start_target_frame * target_frame_duration
    frame_no = math.ceil(target_time / video_frame_duration)
    cur_time = frame_no * video_frame_duration
    frame_ids = []
    while True:
        if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count:
            break
        diff = round((target_time - cur_time) / video_frame_duration, 5)
        add_frames_count = math.ceil(diff)
        frame_no += add_frames_count
        if frame_no >= video_frames_count:
            break
        frame_ids.append(frame_no)
        cur_time += add_frames_count * video_frame_duration
        target_time += target_frame_duration
    # max_target_frames_count == 0 means "no limit"; only truncate otherwise
    if max_target_frames_count > 0:
        frame_ids = frame_ids[:max_target_frames_count]
    return frame_ids


def get_file_creation_date(file_path):
    # On Windows, getctime returns the creation time
    if os.name == 'nt':
        return datetime.fromtimestamp(os.path.getctime(file_path))
    # On Unix/Linux/Mac, fall back to modification time when birth time is unavailable
    else:
        stat = os.stat(file_path)
        return datetime.fromtimestamp(stat.st_birthtime if hasattr(stat, 'st_birthtime') else stat.st_mtime)


def truncate_for_filesystem(s, max_bytes=255):
    """Binary-search the longest prefix of s whose UTF-8 encoding fits in max_bytes."""
    if len(s.encode('utf-8')) <= max_bytes:
        return s
    l, r = 0, len(s)
    while l < r:
        m = (l + r + 1) // 2
        if len(s[:m].encode('utf-8')) <= max_bytes:
            l = m
        else:
            r = m - 1
    return s[:l]


@lru_cache(maxsize=100)
def get_video_info(video_path):
    cap = cv2.VideoCapture(video_path)
    fps = round(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return fps, width, height, frame_count


def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL=True):
    """Extract the nth frame from a video as a PIL image, or as a PyTorch
    tensor normalized to [-1, 1] when return_PIL is False."""
    cap = cv2.VideoCapture(file_name)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {file_name}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Handle out-of-bounds frame numbers
    if frame_no >= total_frames or frame_no < 0:
        if return_last_if_missing:
            frame_no = total_frames - 1
        else:
            cap.release()
            raise IndexError(f"Frame {frame_no} out of bounds (0-{total_frames-1})")

    # Get frame
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
    ret, frame = cap.read()
    cap.release()

    if not ret:
        raise ValueError(f"Failed to read frame {frame_no}")

    # Convert BGR->RGB, reshape to (C, H, W), normalize to [-1, 1]
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if return_PIL:
        return Image.fromarray(frame)
    else:
        return (torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5) - 1.0

# def get_video_frame(file_name, frame_no):
#     decord.bridge.set_bridge('torch')
#     reader = decord.VideoReader(file_name)
#     frame = reader.get_batch([frame_no]).squeeze(0)
#     img = Image.fromarray(frame.numpy().astype(np.uint8))
#     return img
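
# Example usage (hypothetical file name): remap a source clip onto a 16 fps
# timeline, reading at most 81 frames starting from target frame 0.
#
#   fps, width, height, frame_count = get_video_info("input.mp4")
#   frame_ids = resample(fps, frame_count, max_target_frames_count=81,
#                        target_fps=16, start_target_frame=0)
#   first_frame = get_video_frame("input.mp4", frame_ids[0])  # PIL.Image by default
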
def convert_image_to_video(image):
    if image is None:
        return None

    # Convert PIL/numpy image to OpenCV format if needed
    if isinstance(image, np.ndarray):
        # Gradio images are typically RGB, OpenCV expects BGR
        img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        # Handle PIL Image
        img_array = np.array(image)
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    height, width = img_bgr.shape[:2]

    # Create a temporary single-frame video file (auto-cleaned by Gradio)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height))
        out.write(img_bgr)
        out.release()
        return temp_video.name


def resize_lanczos(img, h, w):
    # img is a (C, H, W) tensor in [-1, 1]; round-trip through PIL for Lanczos resampling
    img = (img + 1).float().mul_(127.5)
    img = Image.fromarray(np.clip(img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = img.resize((w, h), resample=Image.Resampling.LANCZOS)
    img = torch.from_numpy(np.array(img).astype(np.float32)).movedim(-1, 0)
    return img.div(127.5).sub_(1)


def remove_background(img, session=None):
    if session is None:
        session = new_session()
    # img is a (C, H, W) tensor in [0, 1]
    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = remove(img, session=session, alpha_matting=True, bgcolor=[255, 255, 255, 0]).convert('RGB')
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)


def convert_image_to_tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)


def convert_tensor_to_image(t, frame_no=0):
    if len(t.shape) == 4:
        t = t[:, frame_no]
    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1, 2, 0).to(torch.uint8).cpu().numpy())


def save_image(tensor_image, name, frame_no=-1):
    convert_tensor_to_image(tensor_image, frame_no).save(name)


def get_outpainting_full_area_dimensions(frame_height, frame_width, outpainting_dims):
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right = outpainting_dims
    frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
    frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
    return frame_height, frame_width


def rgb_bw_to_rgba_mask(img, thresh=127):
    a = img.convert('L').point(lambda p: 255 if p > thresh else 0)  # alpha
    out = Image.new('RGBA', img.size, (255, 255, 255, 0))  # white, transparent
    out.putalpha(a)  # white where alpha=255
    return out


def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size=8):
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right = outpainting_dims
    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
    height = int(raw_height / block_size) * block_size
    extra_height = raw_height - height
    raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
    width = int(raw_width / block_size) * block_size
    extra_width = raw_width - width
    margin_top = int(outpainting_top / (100 + outpainting_top + outpainting_bottom) * final_height)
    if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
        margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
    if (margin_top + height) > final_height or outpainting_bottom == 0:
        margin_top = final_height - height
    margin_left = int(outpainting_left / (100 + outpainting_left + outpainting_right) * final_width)
    if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
        # distribute the leftover *width* between the left and right margins
        margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
    if (margin_left + width) > final_width or outpainting_right == 0:
        margin_left = final_width - width
    return height, width, margin_top, margin_left
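
# Worked example (hypothetical numbers): with 25% outpainting on top only,
# a 1280x720 final canvas keeps a 720 / 1.25 = 576-pixel-high source region:
#
#   get_outpainting_frame_location(720, 1280, (25, 0, 0, 0))
#   # -> height=576, width=1280, margin_top=144, margin_left=0
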
def rescale_and_crop(img, w, h):
    """Center-crop img to the target aspect ratio, then resize to (w, h)."""
    ow, oh = img.size
    target_ratio = w / h
    orig_ratio = ow / oh

    if orig_ratio > target_ratio:
        # Source is wider than the target: crop width first
        nw = int(oh * target_ratio)
        img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
    else:
        # Source is taller than the target: crop height first
        nh = int(ow / target_ratio)
        img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))

    return img.resize((w, h), Image.LANCZOS)


def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size=16):
    if fit_into_canvas is None or fit_into_canvas == 2:
        # return image_height, image_width
        return canvas_height, canvas_width
    if fit_into_canvas == 1:
        scale1 = min(canvas_height / image_height, canvas_width / image_width)
        scale2 = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2)
    else:  # 0 or 2 (crop): preserve the canvas pixel budget
        scale = (canvas_height * canvas_width / (image_height * image_width)) ** (1 / 2)

    new_height = round(image_height * scale / block_size) * block_size
    new_width = round(image_width * scale / block_size) * block_size
    return new_height, new_width


def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size=16):
    if fit_crop:
        image = rescale_and_crop(image, canvas_width, canvas_height)
        new_width, new_height = image.size
    else:
        image_width, image_height = image.size
        new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size=block_size)
        image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return image, new_height, new_width


def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas=0, block_size=16, outpainting_dims=None, background_ref_outpainted=True, inpaint_color=127.5):
    if rm_background:
        session = new_session()

    output_list = []
    output_mask_list = []
    for i, img in enumerate(img_list):
        width, height = img.size
        resized_mask = None
        if fit_into_canvas is None or (any_background_ref == 1 and i == 0) or any_background_ref == 2:
            if outpainting_dims is not None and background_ref_outpainted:
                resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame=True, outpainting_dims=outpainting_dims, return_mask=True, return_image=True)
            elif img.size != (budget_width, budget_height):
                resized_image = img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
            else:
                resized_image = img
        elif fit_into_canvas == 1:
            # Letterbox onto a white canvas, preserving aspect ratio
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)
            new_width = int(width * scale)
            resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            top = (budget_height - new_height) // 2
            left = (budget_width - new_width) // 2
            white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image)
            resized_image = Image.fromarray(white_canvas)
        else:
            # Keep the pixel budget, snapping each side to a multiple of block_size
            scale = (budget_height * budget_width / (height * width)) ** (1 / 2)
            new_height = int(round(height * scale / block_size) * block_size)
            new_width = int(round(width * scale / block_size) * block_size)
            resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

        if rm_background and not ((any_background_ref and i == 0) or any_background_ref == 2):
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size=1, alpha_matting_background_threshold=70, alpha_foreground_background_threshold=100, alpha_matting=True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size=1, alpha_matting=True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image)
        output_mask_list.append(resized_mask)
    return output_list, output_mask_list
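
# Example (hypothetical sizes): fit a 1920x1080 image into an 832x480 canvas.
# With fit_into_canvas=0 the pixel budget is preserved,
# scale = sqrt(832*480 / (1920*1080)) ~ 0.439, and each side snaps to a
# multiple of block_size:
#
#   h, w = calculate_new_dimensions(480, 832, 1080, 1920, fit_into_canvas=0)
#   # -> (480, 848)
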
def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg=127.5, device="cpu", full_frame=False, outpainting_dims=None, return_mask=False, return_image=False):
    inpaint_color = canvas_tf_bg / 127.5 - 1

    ref_width, ref_height = ref_img.size
    if (ref_height, ref_width) == image_size and outpainting_dims is None:
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        # Nothing is inpainted, so the single-channel mask is all zeros
        canvas = torch.zeros_like(ref_img[0:1]) if return_mask else None
    else:
        if outpainting_dims is not None:
            final_height, final_width = image_size
            canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
        else:
            canvas_height, canvas_width = image_size
        if full_frame:
            new_height = canvas_height
            new_width = canvas_width
            top = left = 0
        else:
            # if fill_max and (canvas_height - new_height) < 16:
            #     new_height = canvas_height
            # if fill_max and (canvas_width - new_width) < 16:
            #     new_width = canvas_width
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
        ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        if outpainting_dims is not None:
            canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype=torch.float, device=device)  # [-1, 1]
            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
        else:
            canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype=torch.float, device=device)  # [-1, 1]
            canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
        ref_img = canvas
        canvas = None
        if return_mask:
            # Mask is 1 where the model must inpaint, 0 where the reference image sits
            if outpainting_dims is not None:
                canvas = torch.ones((1, 1, final_height, final_width), dtype=torch.float, device=device)
                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
            else:
                canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype=torch.float, device=device)
                canvas[:, :, top:top + new_height, left:left + new_width] = 0
            canvas = canvas.to(device)
    if return_image:
        return convert_tensor_to_image(ref_img), canvas
    return ref_img.to(device), canvas
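
# Usage sketch (hypothetical input file): place a reference image on a 480x832
# canvas with 20% outpainting on the left; returns the padded PIL image and a
# (1, 1, H, W) mask that is 1 wherever the model should inpaint.
#
#   padded, mask = fit_image_into_canvas(Image.open("ref.png"), (480, 832),
#                                        outpainting_dims=(0, 0, 20, 0),
#                                        return_mask=True, return_image=True)
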
def prepare_video_guide_and_mask(
    video_guides,
    video_masks,
    pre_video_guide,
    image_size,
    current_video_length=81,
    latent_size=4,
    any_mask=False,
    any_guide_padding=False,
    guide_inpaint_color=127.5,
    extract_guide_from_window_start=False,
    keep_video_guide_frames=[],
    inject_frames=[],
    outpainting_dims=None,
    device="cpu",
):
    src_videos, src_masks = [], []
    inpaint_color = guide_inpaint_color / 127.5 - 1
    prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
    for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
        src_video = src_mask = None
        if cur_video_guide is not None:
            src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.)  # c, f, h, w
        if cur_video_mask is not None and any_mask:
            src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1]  # c, f, h, w

        # Prepend the frames carried over from the previous sliding window
        if pre_video_guide is not None and not extract_guide_from_window_start:
            src_video = pre_video_guide if src_video is None else torch.cat([pre_video_guide, src_video], dim=1)
            if any_mask:
                src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat([torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)

        if src_video is None:
            if any_guide_padding:
                src_video = torch.full((3, current_video_length, *image_size), inpaint_color, dtype=torch.float, device="cpu")
                if any_mask:
                    src_mask = torch.zeros_like(src_video[0:1])
        elif src_video.shape[1] < current_video_length:
            if any_guide_padding:
                # Pad the guide with the inpaint color, and the mask with ones
                src_video = torch.cat([src_video, torch.full((3, current_video_length - src_video.shape[1], *src_video.shape[-2:]), inpaint_color, dtype=src_video.dtype, device=src_video.device)], dim=1)
                if cur_video_mask is not None and any_mask:
                    src_mask = torch.cat([src_mask, torch.full((1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:]), 1, dtype=src_video.dtype, device=src_video.device)], dim=1)
        else:
            # Truncate to a whole number of latent chunks
            new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
            src_video = src_video[:, :new_num_frames]
            if any_mask:
                src_mask = src_mask[:, :new_num_frames]

        if src_video is not None:
            # Blank out (and mark for inpainting) the frames that should not be kept
            for k, keep in enumerate(keep_video_guide_frames):
                if not keep:
                    pos = prepend_count + k
                    src_video[:, pos:pos + 1] = inpaint_color
                    if any_mask:
                        src_mask[:, pos:pos + 1] = 1

            # Overwrite individual frames with externally injected images
            for k, frame in enumerate(inject_frames):
                if frame is not None:
                    pos = prepend_count + k
                    # fit_image_into_canvas expects the unnormalized [0, 255] background color
                    src_video[:, pos:pos + 1], frame_mask = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask=True)
                    if any_mask:
                        src_mask[:, pos:pos + 1] = frame_mask

        src_videos.append(src_video)
        src_masks.append(src_mask)
    return src_videos, src_masks
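
# Usage sketch (hypothetical tensors): pad a 33-frame guide to an 81-frame
# window and blank every frame after the first so the model regenerates them.
# Here `guide` is a (f, h, w, c) uint8 tensor and `mask` a (f, h, w, 1)
# 0/255 tensor.
#
#   src_videos, src_masks = prepare_video_guide_and_mask(
#       [guide], [mask], pre_video_guide=None, image_size=(480, 832),
#       current_video_length=81, any_mask=True, any_guide_padding=True,
#       keep_video_guide_frames=[True] + [False] * 32)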