import sys
import os
import json
import time
import psutil
# import ffmpeg
import imageio
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
import numpy as np
import gradio as gr

from .tools.painter import mask_painter
from .tools.interact_tools import SamControler
from .tools.misc import get_device
from .tools.download_util import load_file_from_url
from segment_anything.modeling.image_encoder import window_partition, window_unpartition, get_rel_pos, Block as image_encoder_block
from .utils.get_default_model import get_matanyone_model
from .matanyone.inference.inference_core import InferenceCore
from .matanyone_wrapper import matanyone

arg_device = "cuda"
arg_sam_model_type = "vit_h"
arg_mask_save = False
model_loaded = False
model = None
matanyone_model = None

# SAM generator
class MaskGenerator():
    def __init__(self, sam_checkpoint, device):
        global arg_device
        arg_device = device
        self.samcontroler = SamControler(sam_checkpoint, arg_sam_model_type, arg_device)

    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image

# convert points input to prompt state
def get_prompt(click_state, click_input):
    inputs = json.loads(click_input)
    points = click_state[0]
    labels = click_state[1]
    for input in inputs:
        points.append(input[:2])
        labels.append(input[2])
    click_state[0] = points
    click_state[1] = labels
    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
    return prompt

def get_frames_from_image(image_input, image_state):
    """
    Build a pseudo two-frame "video" state from a single input image so the
    video pipeline can be reused for image matting.
    Args:
        image_input: np.ndarray, the uploaded image
        image_state: dict, the Gradio state to (re)initialize
    """
    load_sam()

    user_name = time.time()
    frames = [image_input] * 2  # hardcode: mimic a video with 2 frames
    image_size = (frames[0].shape[0], frames[0].shape[1])
    # initialize image_state
    image_state = {
        "user_name": user_name,
        "image_name": "output.png",
        "origin_images": frames,
        "painted_images": frames.copy(),
        "masks": [np.zeros((frames[0].shape[0], frames[0].shape[1]), np.uint8)] * len(frames),
        "logits": [None] * len(frames),
        "select_frame_number": 0,
        "last_frame_number": 0,
        "fps": None
    }
    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
    set_image_encoder_patch()
    torch.cuda.empty_cache()
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
    torch.cuda.empty_cache()
    return image_state, image_info, image_state["origin_images"][0], \
        gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
        gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \
        gr.update(visible=False), gr.update(visible=True), \
        gr.update(visible=True)

# extract frames from uploaded video
def get_frames_from_video(video_input, video_state):
    """
    Args:
        video_path: str
        timestamp: float64
    Return
        [[0:nearest_frame], [nearest_frame:], nearest_frame]
    """
    load_sam()
    while model is None:
        time.sleep(1)

    video_path = video_input
    frames = []
    user_name = time.time()

    # extract Audio
    # try:
    #     audio_path = video_input.replace(".mp4", "_audio.wav")
    #     ffmpeg.input(video_path).output(audio_path, format='wav',
    #         acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True)
    # except Exception as e:
    #     print(f"Audio extraction error: {str(e)}")
    #     audio_path = ""  # Set to "" if extraction fails
    # print(f'audio_path: {audio_path}')
    audio_path = ""

    # extract frames
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                current_memory_usage = psutil.virtual_memory().percent
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if current_memory_usage > 90:
                    break
            else:
                break
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    image_size = (frames[0].shape[0], frames[0].shape[1])

    # resize if resolution too big
    if image_size[0] >= 1280 and image_size[1] >= 1280:
        scale = 1080 / min(image_size)
        new_w = int(image_size[1] * scale)
        new_h = int(image_size[0] * scale)
        # update frames
        frames = [cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA) for f in frames]
        # update image_size
        image_size = (frames[0].shape[0], frames[0].shape[1])

    # initialize video_state
    video_state = {
        "user_name": user_name,
        "video_name": os.path.split(video_path)[-1],
        "origin_images": frames,
        "painted_images": frames.copy(),
        "masks": [np.zeros((frames[0].shape[0], frames[0].shape[1]), np.uint8)] * len(frames),
        "logits": [None] * len(frames),
        "select_frame_number": 0,
        "last_frame_number": 0,
        "fps": fps,
        "audio": audio_path
    }
    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
    set_image_encoder_patch()
    torch.cuda.empty_cache()
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
    torch.cuda.empty_cache()
    return video_state, video_info, video_state["origin_images"][0], \
        gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=False), \
        gr.update(visible=False), gr.update(visible=False), \
        gr.update(visible=False), gr.update(visible=True), \
        gr.update(visible=True)

# get the selected frame from the gradio slider
def select_video_template(image_selection_slider, video_state, interactive_state):
    image_selection_slider -= 1
    video_state["select_frame_number"] = image_selection_slider

    # once a new template frame is selected, set the image in sam
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])

    return video_state["painted_images"][image_selection_slider], video_state, interactive_state

def select_image_template(image_selection_slider, video_state, interactive_state):
    image_selection_slider = 0  # fixed for image
    video_state["select_frame_number"] = image_selection_slider

    # once a new template frame is selected, set the image in sam
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])

    return video_state["painted_images"][image_selection_slider], video_state, interactive_state

# set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state):
    interactive_state["track_end_number"] = track_pause_number_slider
    return video_state["painted_images"][track_pause_number_slider], interactive_state

# Memory-friendlier replacement for the SAM ViT image-encoder Block.forward:
# attention runs through F.scaled_dot_product_attention and the MLP is applied in chunks.
def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    def split_mlp(mlp, x, divide=4):
        # run the block MLP on token chunks to reduce peak memory
        x_shape = x.shape
        x = x.view(-1, x.shape[-1])
        chunk_size = int(x.shape[0] / divide)
        x_chunks = torch.split(x, chunk_size)
        for i, x_chunk in enumerate(x_chunks):
            mlp_chunk = mlp.lin1(x_chunk)
            mlp_chunk = mlp.act(mlp_chunk)
            x_chunk[...] = mlp.lin2(mlp_chunk)
        return x.reshape(x_shape)

    def get_decomposed_rel_pos(q, rel_pos_h, rel_pos_w, q_size, k_size) -> torch.Tensor:
        q_h, q_w = q_size
        k_h, k_w = k_size
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)

        B, _, dim = q.shape
        r_q = q.reshape(B, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
        attn = torch.zeros(B, q_h, q_w, k_h, k_w, dtype=q.dtype, device=q.device)
        attn += rel_h[:, :, :, :, None]
        attn += rel_w[:, :, :, None, :]
        return attn.view(B, q_h * q_w, k_h * k_w)

    def pay_attention(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
        attn_mask = None
        if self.use_rel_pos:
            attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
        del q, k, v, attn_mask
        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(x)

    shortcut = x
    x = self.norm1(x)
    # Window partition
    if self.window_size > 0:
        H, W = x.shape[1], x.shape[2]
        x, pad_hw = window_partition(x, self.window_size)

    x = pay_attention(self.attn, x)

    # Reverse window partition
    if self.window_size > 0:
        x = window_unpartition(x, self.window_size, pad_hw, (H, W))

    x += shortcut
    shortcut[...] = self.norm2(x)
    # x += self.mlp(shortcut)
    x += split_mlp(self.mlp, shortcut)

    return x

def set_image_encoder_patch():
    if not hasattr(image_encoder_block, "patched"):
        image_encoder_block.forward = patched_forward
        image_encoder_block.patched = True

# use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData):
    """
    Args:
        template_frame: PIL.Image
        point_prompt: flag for positive or negative button click
        click_state: [[points], [labels]]
    """
    if point_prompt == "Positive":
        coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
        interactive_state["positive_click_times"] += 1
    else:
        coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
        interactive_state["negative_click_times"] += 1

    torch.cuda.empty_cache()
    # prompt for sam model
    set_image_encoder_patch()
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    mask, logit, painted_image = model.first_frame_click(
        image=video_state["origin_images"][video_state["select_frame_number"]],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    video_state["masks"][video_state["select_frame_number"]] = mask
    video_state["logits"][video_state["select_frame_number"]] = logit
    video_state["painted_images"][video_state["select_frame_number"]] = painted_image
    torch.cuda.empty_cache()
    return painted_image, video_state, interactive_state

def add_multi_mask(video_state, interactive_state, mask_dropdown):
    mask = video_state["masks"][video_state["select_frame_number"]]
    interactive_state["multi_mask"]["masks"].append(mask)
    interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
    mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
    select_frame = show_mask(video_state, interactive_state, mask_dropdown)
    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[], []]

def clear_click(video_state, click_state):
    click_state = [[], []]
    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
    return template_frame, click_state

def remove_multi_mask(interactive_state, mask_dropdown):
    interactive_state["multi_mask"]["mask_names"] = []
    interactive_state["multi_mask"]["masks"] = []
    return interactive_state, gr.update(choices=[], value=[])

def show_mask(video_state, interactive_state, mask_dropdown):
    mask_dropdown.sort()
    if video_state["origin_images"]:
        select_frame = video_state["origin_images"][video_state["select_frame_number"]]
        for i in range(len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            mask = interactive_state["multi_mask"]["masks"][mask_number]
            select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number + 2)
        return select_frame

def save_video(frames, output_path, fps):
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8)
    for frame in frames:
        writer.append_data(frame)
    writer.close()
    return output_path

def mask_to_xyxy_box(mask):
    rows, cols = np.where(mask == 255)
    xmin = min(cols)
    xmax = max(cols) + 1
    ymin = min(rows)
    ymax = max(rows) + 1
    xmin = max(xmin, 0)
    ymin = max(ymin, 0)
    xmax = min(xmax, mask.shape[1])
    ymax = min(ymax, mask.shape[0])
    box = [xmin, ymin, xmax, ymax]
    box = [int(x) for x in box]
    return box
# image matting
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
    matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
    if interactive_state["track_end_number"]:
        following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
    else:
        following_frames = video_state["origin_images"][video_state["select_frame_number"]:]

    if interactive_state["multi_mask"]["masks"]:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
        for i in range(1, len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            template_mask = np.clip(template_mask + interactive_state["multi_mask"]["masks"][mask_number] * (mask_number + 1), 0, mask_number + 1)
        video_state["masks"][video_state["select_frame_number"]] = template_mask
    else:
        template_mask = video_state["masks"][video_state["select_frame_number"]]

    # operation error
    if len(np.unique(template_mask)) == 1:
        template_mask[0][0] = 1
    torch.cuda.empty_cache()
    foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask * 255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
    torch.cuda.empty_cache()
    foreground_mat = False

    output_frames = []
    for frame_origin, frame_alpha in zip(following_frames, alpha):
        if foreground_mat:
            frame_alpha[frame_alpha > 127] = 255
            frame_alpha[frame_alpha <= 127] = 0
        else:
            frame_temp = frame_alpha.copy()
            frame_alpha[frame_temp > 127] = 0
            frame_alpha[frame_temp <= 127] = 255

        output_frame = np.bitwise_and(frame_origin, 255 - frame_alpha)
        frame_grey = frame_alpha.copy()
        frame_grey[frame_alpha == 255] = 255
        output_frame += frame_grey
        output_frames.append(output_frame)
    foreground = output_frames

    foreground_output = Image.fromarray(foreground[-1])
    alpha_output = alpha[-1][:, :, 0]
    frame_temp = alpha_output.copy()
    alpha_output[frame_temp > 127] = 0
    alpha_output[frame_temp <= 127] = 255
    bbox_info = mask_to_xyxy_box(alpha_output)
    h = alpha_output.shape[0]
    w = alpha_output.shape[1]
    bbox_info = [str(int(bbox_info[0] / w * 100)), str(int(bbox_info[1] / h * 100)), str(int(bbox_info[2] / w * 100)), str(int(bbox_info[3] / h * 100))]
    bbox_info = ":".join(bbox_info)
    alpha_output = Image.fromarray(alpha_output)
    return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True)

# video matting
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
    matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
    # if interactive_state["track_end_number"]:
    #     following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
    # else:
    end_slider = max(video_state["select_frame_number"] + 1, end_slider)
    following_frames = video_state["origin_images"][video_state["select_frame_number"]: end_slider]

    if interactive_state["multi_mask"]["masks"]:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
        for i in range(1, len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            template_mask = np.clip(template_mask + interactive_state["multi_mask"]["masks"][mask_number] * (mask_number + 1), 0, mask_number + 1)
        video_state["masks"][video_state["select_frame_number"]] = template_mask
    else:
        template_mask = video_state["masks"][video_state["select_frame_number"]]
    fps = video_state["fps"]
    audio_path = video_state["audio"]

    # operation error
    if len(np.unique(template_mask)) == 1:
        template_mask[0][0] = 1
    torch.cuda.empty_cache()
    foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask * 255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
    torch.cuda.empty_cache()
    output_frames = []
    foreground_mat = matting_type == "Foreground"
    if not foreground_mat:
        new_alpha = []
        for frame_alpha in alpha:
            frame_temp = frame_alpha.copy()
            frame_alpha[frame_temp > 127] = 0
            frame_alpha[frame_temp <= 127] = 255
            new_alpha.append(frame_alpha)
        alpha = new_alpha
    # for frame_origin, frame_alpha in zip(following_frames, alpha):
    #     if foreground_mat:
    #         frame_alpha[frame_alpha > 127] = 255
    #         frame_alpha[frame_alpha <= 127] = 0
    #     else:
    #         frame_temp = frame_alpha.copy()
    #         frame_alpha[frame_temp > 127] = 0
    #         frame_alpha[frame_temp <= 127] = 255
    #     output_frame = np.bitwise_and(frame_origin, 255 - frame_alpha)
    #     frame_grey = frame_alpha.copy()
    #     frame_grey[frame_alpha == 255] = 127
    #     output_frame += frame_grey
    #     output_frames.append(output_frame)
    foreground = following_frames

    if not os.path.exists("mask_outputs"):
        os.makedirs("mask_outputs")
    file_name = video_state["video_name"]
    file_name = ".".join(file_name.split(".")[:-1])
    foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps)
    # foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path)  # import video_input to name the output video
    alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps)
    # alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path)  # import video_input to name the output video

    return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

def show_outputs():
    return gr.update(visible=True), gr.update(visible=True)

def add_audio_to_video(video_path, audio_path, output_path):
    pass
    # try:
    #     video_input = ffmpeg.input(video_path)
    #     audio_input = ffmpeg.input(audio_path)
    #     _ = (
    #         ffmpeg
    #         .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac")
    #         .run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
    #     )
    #     return output_path
    # except ffmpeg.Error as e:
    #     print(f"FFmpeg error:\n{e.stderr.decode()}")
    #     return None

def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.
""" frames = torch.from_numpy(np.asarray(frames)) _, h, w, _ = frames.shape if gray2rgb: frames = np.repeat(frames, 3, axis=3) if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) video_temp_path = output_path.replace(".mp4", "_temp.mp4") # resize back to ensure input resolution imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7, codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"]) # add audio to video if audio path exists if audio_path != "" and os.path.exists(audio_path): output_path = add_audio_to_video(video_temp_path, audio_path, output_path) os.remove(video_temp_path) return output_path else: return video_temp_path # reset all states for a new input def restart(): return { "user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 30 }, { "inference_times": 0, "negative_click_times" : 0, "positive_click_times": 0, "mask_save": False, "multi_mask": { "mask_names": [], "masks": [] }, "track_end_number": None, }, [[],[]], None, None, \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) def load_sam(): global model_loaded global model global matanyone_model model.samcontroler.sam_controler.model.to(arg_device) matanyone_model.to(arg_device) def load_unload_models(selected): global model_loaded global model global matanyone_model if selected: # print("Matanyone Tab Selected") if model_loaded: load_sam() else: # args, defined in track_anything.py sam_checkpoint_url_dict = { 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" } # os.path.join('.') from mmgp import offload # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".") sam_checkpoint = None transfer_stream = torch.cuda.Stream() with torch.cuda.stream(transfer_stream): # initialize sams model = MaskGenerator(sam_checkpoint, arg_device) from .matanyone.model.matanyone import MatAnyone matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) matanyone_model = matanyone_model.to(arg_device).eval() matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) model_loaded = True else: # print("Matanyone Tab UnSelected") import gc model.samcontroler.sam_controler.model.to("cpu") matanyone_model.to("cpu") gc.collect() torch.cuda.empty_cache() def get_vmc_event_handler(): return load_unload_models def export_to_vace_video_input(foreground_video_output): gr.Info("Masked Video Input transferred to Vace For Inpainting") return "V#" + str(time.time()), foreground_video_output def export_image(image_refs, image_output): gr.Info("Masked Image transferred to Current Video") if image_refs == None: image_refs =[] image_refs.append( image_output) return image_refs def 
def export_image_mask(image_input, image_mask):
    gr.Info("Input Image & Mask transferred to Current Video")
    return Image.fromarray(image_input), image_mask

def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
    gr.Info("Original Video and Full Mask have been transferred")
    # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
    if "custom_edit" in model_type and False:
        return gr.update(), alpha_video_output
    else:
        return foreground_video_output, alpha_video_output

def teleport_to_video_tab(tab_state):
    from wgp import set_new_tab
    set_new_tab(tab_state, 0)
    return gr.Tabs(selected="video_gen")

def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
    # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
    media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"  # download assets

    gr.Markdown("Mask editing is provided by MatAnyone")
    gr.Markdown("If you have trouble creating the perfect mask, keep these tips in mind:")
    gr.Markdown("- Using the MatAnyone settings you can also define Negative Point Prompts to remove parts of the current selection.")
    gr.Markdown("- Sometimes it is hard to fit everything you want in a single mask; it may be much easier to combine multiple independent sub-masks before producing the matting: each sub-mask is created by selecting an area of the image and clicking the Add Mask button. Sub-masks can then be enabled / disabled in the MatAnyone settings.")

    with gr.Column(visible=True):
        with gr.Row():
            with gr.Accordion("Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Case 1: Single Target")
                        gr.Video(value="preprocessing/matanyone/tutorial_single_target.mp4", elem_classes="video")
                    with gr.Column():
                        gr.Markdown("### Case 2: Multiple Targets")
                        gr.Video(value="preprocessing/matanyone/tutorial_multi_targets.mp4", elem_classes="video")

    with gr.Tabs():
        with gr.TabItem("Video"):
            click_state = gr.State([[], []])

            interactive_state = gr.State({
                "inference_times": 0,
                "negative_click_times": 0,
                "positive_click_times": 0,
                "mask_save": arg_mask_save,
                "multi_mask": {
                    "mask_names": [],
                    "masks": []
                },
                "track_end_number": None,
            })

            video_state = gr.State({
                "user_name": "",
                "video_name": "",
                "origin_images": None,
                "painted_images": None,
                "masks": None,
                "inpaint_masks": None,
                "logits": None,
                "select_frame_number": 0,
                "fps": 16,
                "audio": "",
            })

            with gr.Column(visible=True):
                with gr.Row():
                    with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
                        with gr.Row():
                            erode_kernel_size = gr.Slider(label='Erode Kernel Size', minimum=0, maximum=30, step=1, value=10, info="Erosion on the added mask", interactive=True)
                            dilate_kernel_size = gr.Slider(label='Dilate Kernel Size', minimum=0, maximum=30, step=1, value=10, info="Dilation on the added mask", interactive=True)
                        with gr.Row():
                            image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False)
                            end_selection_slider = gr.Slider(minimum=1, maximum=300, step=1, value=81, label="Last Frame to Process", info="Last Frame to Process", visible=False)
                            track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="End frame", visible=False)
                        with gr.Row():
                            point_prompt = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                info="Click to add positive or negative point for target mask",
                                interactive=True,
                                visible=False,
                                min_width=100,
                                scale=1)
                            matting_type = gr.Radio(
                                choices=["Foreground", "Background"],
                                value="Foreground",
                                label="Matting Type",
                                info="Type of Video Matting to Generate",
                                interactive=True,
                                visible=False,
                                min_width=100,
                                scale=1)
                            mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False, scale=2)

            # input video
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    gr.Markdown("## Step1: Upload video")
                with gr.Column(scale=2):
                    step2_title = gr.Markdown("## Step2: Add masks (Several clicks then **`Add Mask`** one by one)", visible=False)
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    video_input = gr.Video(label="Input Video", elem_classes="video")
                    extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button")
                with gr.Column(scale=2):
                    video_info = gr.Textbox(label="Video Info", visible=False)
                    template_frame = gr.Image(label="Start Frame", type="pil", interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
                    with gr.Row():
                        clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
                        add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
                        remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100)  # no use
                        matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
                    with gr.Row():
                        gr.Markdown("")

            # output video
            with gr.Column() as output_row:  # equal_height=True
                with gr.Row():
                    with gr.Column(scale=2):
                        foreground_video_output = gr.Video(label="Original Video Input", visible=False, elem_classes="video")
                        foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
                    with gr.Column(scale=2):
                        alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
                        export_image_mask_btn = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
                with gr.Row():
                    with gr.Row(visible=False):
                        export_to_vace_video_14B_btn = gr.Button("Export to current Video Input Video For Inpainting", visible=False)
                    with gr.Row(visible=True):
                        export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible=False)

            export_to_current_video_engine_btn.click(
                fn=export_to_current_video_engine,
                inputs=[model_choice, foreground_video_output, alpha_video_output],
                outputs=[vace_video_input, vace_video_mask]).then(
                # video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs=[tab_state], outputs=[tabs])

            # first step: get the video information
            extract_frames_button.click(
                fn=get_frames_from_video,
                inputs=[video_input, video_state],
                outputs=[video_state, video_info, template_frame,
                         image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type,
                         clear_button_click, add_mask_button, matting_button, template_frame, foreground_video_output,
                         alpha_video_output, foreground_output_button, export_image_mask_btn, mask_dropdown, step2_title]
            )

            # second step: select images from slider
            image_selection_slider.release(fn=select_video_template,
                                           inputs=[image_selection_slider, video_state, interactive_state],
                                           outputs=[template_frame, video_state, interactive_state],
                                           api_name="select_image")
            track_pause_number_slider.release(fn=get_end_number,
                                              inputs=[track_pause_number_slider, video_state, interactive_state],
                                              outputs=[template_frame, interactive_state],
                                              api_name="end_image")

            # click select image to get mask using sam
            template_frame.select(
                fn=sam_refine,
                inputs=[video_state, point_prompt, click_state, interactive_state],
                outputs=[template_frame, video_state, interactive_state]
            )

            # add different mask
            add_mask_button.click(
                fn=add_multi_mask,
                inputs=[video_state, interactive_state, mask_dropdown],
                outputs=[interactive_state, mask_dropdown, template_frame, click_state]
            )

            remove_mask_button.click(
                fn=remove_multi_mask,
                inputs=[interactive_state, mask_dropdown],
                outputs=[interactive_state, mask_dropdown]
            )

            # video matting
            matting_button.click(
                fn=show_outputs,
                inputs=[],
                outputs=[foreground_video_output, alpha_video_output]).then(
                fn=video_matting,
                inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
                outputs=[foreground_video_output, alpha_video_output, foreground_video_output, alpha_video_output, export_to_vace_video_14B_btn, export_to_current_video_engine_btn]
            )

            # click to get mask
            mask_dropdown.change(
                fn=show_mask,
                inputs=[video_state, interactive_state, mask_dropdown],
                outputs=[template_frame]
            )

            # clear input
            video_input.change(
                fn=restart,
                inputs=[],
                outputs=[
                    video_state, interactive_state, click_state,
                    foreground_video_output, alpha_video_output, template_frame,
                    image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt,
                    export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type,
                    clear_button_click, add_mask_button, matting_button, template_frame,
                    foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button,
                    export_image_mask_btn, mask_dropdown, video_info, step2_title
                ],
                queue=False,
                show_progress=False)

            video_input.clear(
                fn=restart,
                inputs=[],
                outputs=[
                    video_state, interactive_state, click_state,
                    foreground_video_output, alpha_video_output, template_frame,
                    image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt,
                    export_to_vace_video_14B_btn, export_to_current_video_engine_btn, matting_type,
                    clear_button_click, add_mask_button, matting_button, template_frame,
                    foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button,
                    export_image_mask_btn, mask_dropdown, video_info, step2_title
                ],
                queue=False,
                show_progress=False)

            # points clear
            clear_button_click.click(
                fn=clear_click,
                inputs=[video_state, click_state],
                outputs=[template_frame, click_state],
            )

        with gr.TabItem("Image"):
            click_state = gr.State([[], []])

            interactive_state = gr.State({
                "inference_times": 0,
                "negative_click_times": 0,
                "positive_click_times": 0,
                "mask_save": False,
                "multi_mask": {
                    "mask_names": [],
                    "masks": []
                },
                "track_end_number": None,
            })

            image_state = gr.State({
                "user_name": "",
                "image_name": "",
                "origin_images": None,
                "painted_images": None,
                "masks": None,
                "inpaint_masks": None,
                "logits": None,
                "select_frame_number": 0,
                "fps": 30
            })

            with gr.Group(elem_classes="gr-monochrome-group", visible=True):
                with gr.Row():
                    with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
                        with gr.Row():
                            erode_kernel_size = gr.Slider(label='Erode Kernel Size', minimum=0, maximum=30, step=1, value=10, info="Erosion on the added mask", interactive=True)
                            dilate_kernel_size = gr.Slider(label='Dilate Kernel Size', minimum=0, maximum=30, step=1, value=10, info="Dilation on the added mask", interactive=True)
                        with gr.Row():
                            image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Num of Refinement Iterations", info="More iterations → More details & More time", visible=False)
                            track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
                        with gr.Row():
                            point_prompt = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                info="Click to add positive or negative point for target mask",
                                interactive=True,
                                visible=False,
                                min_width=100,
                                scale=1)
                            mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)

            with gr.Column():
                # input image
                with gr.Row(equal_height=True):
                    with gr.Column(scale=2):
                        gr.Markdown("## Step1: Upload image")
                    with gr.Column(scale=2):
                        step2_title = gr.Markdown("## Step2: Add masks (Several clicks then **`Add Mask`** one by one)", visible=False)
                with gr.Row(equal_height=True):
                    with gr.Column(scale=2):
                        image_input = gr.Image(label="Input Image", elem_classes="image")
                        extract_frames_button = gr.Button(value="Load Image", interactive=True, elem_classes="new_button")
                    with gr.Column(scale=2):
                        image_info = gr.Textbox(label="Image Info", visible=False)
                        template_frame = gr.Image(type="pil", label="Start Frame", interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
                        with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                            clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, elem_classes="new_button", min_width=100)
                            add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
                            remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
                            matting_button = gr.Button(value="Image Matting", interactive=True, visible=False, elem_classes="green_button", min_width=100)

                # output image
                with gr.Row(equal_height=True):
                    foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
                    alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image")
                with gr.Row(equal_height=True):
                    bbox_info = gr.Text(label="Mask BBox Info (Left:Top:Right:Bottom)", interactive=False)
                with gr.Row():
                    # with gr.Row():
                    export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button")
                    # with gr.Column(scale=2, visible=True):
                    export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")

            export_image_btn.click(
                fn=export_image,
                inputs=[vace_image_refs, foreground_image_output],
                outputs=[vace_image_refs]).then(
                # video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs=[tab_state], outputs=[tabs])

            export_image_mask_btn.click(
                fn=export_image_mask,
                inputs=[image_input, alpha_image_output],
                outputs=[vace_image_input, vace_image_mask]).then(
                # video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs=[tab_state], outputs=[tabs])

            # first step: get the image information
            extract_frames_button.click(
                fn=get_frames_from_image,
                inputs=[image_input, image_state],
                outputs=[image_state, image_info, template_frame,
                         image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button,
                         matting_button, template_frame, foreground_image_output, alpha_image_output, bbox_info,
                         export_image_btn, export_image_mask_btn, mask_dropdown, step2_title]
            )

            # points clear
            clear_button_click.click(
                fn=clear_click,
                inputs=[image_state, click_state],
                outputs=[template_frame, click_state],
            )

            # second step: select images from slider
            image_selection_slider.release(fn=select_image_template,
                                           inputs=[image_selection_slider, image_state, interactive_state],
                                           outputs=[template_frame, image_state, interactive_state],
                                           api_name="select_image")
            track_pause_number_slider.release(fn=get_end_number,
                                              inputs=[track_pause_number_slider, image_state, interactive_state],
                                              outputs=[template_frame, interactive_state],
                                              api_name="end_image")

            # click select image to get mask using sam
            template_frame.select(
                fn=sam_refine,
                inputs=[image_state, point_prompt, click_state, interactive_state],
                outputs=[template_frame, image_state, interactive_state]
            )

            # add different mask
            add_mask_button.click(
                fn=add_multi_mask,
                inputs=[image_state, interactive_state, mask_dropdown],
                outputs=[interactive_state, mask_dropdown, template_frame, click_state]
            )

            remove_mask_button.click(
                fn=remove_multi_mask,
                inputs=[interactive_state, mask_dropdown],
                outputs=[interactive_state, mask_dropdown]
            )

            # image matting
            matting_button.click(
                fn=image_matting,
                inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
                outputs=[foreground_image_output, alpha_image_output, bbox_info, export_image_btn, export_image_mask_btn]
            )