diff --git a/README.md b/README.md index fa9d69f..9b7ccbf 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,14 @@ ## 🔥 Latest News!! -* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you ! +* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you ! - A new queuing system that lets you stack in a queue as many text2video and imag2video tasks as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...). - Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge you video by x2 or x4. Check these new advanced options. - Wan Vace Control Net support : with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below. + - Integrated *Matanyone* tool directly inside WanGP so that you can create easily inpainting masks + - Sliding Window generation for Vace, create windows that can last dozen of seconds + - A new UI, tabs were replaced by a Dropdown box to easily switch models + * Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2 1.3B model : Image 2 Video is now accessible to even lower hardware configuration. It is not as good as the 14B models but very impressive for its size. You can choose any of those models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun) * Mar 26 2025: 👋 Good news ! Official support for RTX 50xx please check the installation instructions below. * Mar 24 2025: 👋 Wan2.1GP v3.2: diff --git a/preprocessing/matanyone/__init__.py b/preprocessing/matanyone/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py new file mode 100644 index 0000000..a58fa67 --- /dev/null +++ b/preprocessing/matanyone/app.py @@ -0,0 +1,683 @@ +import sys + +import os +import json +import time +import psutil +import ffmpeg +import imageio +from PIL import Image + +import cv2 +import torch +import numpy as np +import gradio as gr +from .tools.painter import mask_painter +from .tools.interact_tools import SamControler +from .tools.misc import get_device +from .tools.download_util import load_file_from_url + +from .utils.get_default_model import get_matanyone_model +from .matanyone.inference.inference_core import InferenceCore +from .matanyone_wrapper import matanyone + +arg_device = "cuda" +arg_sam_model_type="vit_h" +arg_mask_save = False +model = None +matanyone_model = None + +# SAM generator +class MaskGenerator(): + def __init__(self, sam_checkpoint, device): + global args_device + args_device = device + self.samcontroler = SamControler(sam_checkpoint, arg_sam_model_type, arg_device) + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + return mask, logit, painted_image + +# convert points input to prompt state +def get_prompt(click_state, click_input): + inputs = json.loads(click_input) + points = click_state[0] + labels = click_state[1] + for input in inputs: + points.append(input[:2]) + labels.append(input[2]) + click_state[0] = points + click_state[1] = labels + prompt = { + "prompt_type":["click"], + "input_point":click_state[0], + "input_label":click_state[1], + "multimask_output":"True", + } + return prompt + +def get_frames_from_image(image_input, image_state): + """ + Args: + video_path:str + timestamp:float64 + Return + [[0:nearest_frame], [nearest_frame:], nearest_frame] + """ + + user_name = time.time() + frames = [image_input] * 2 # hardcode: mimic a video with 2 frames + image_size = (frames[0].shape[0],frames[0].shape[1]) + # initialize video_state + image_state = { + "user_name": user_name, + "image_name": "output.png", + "origin_images": frames, + "painted_images": frames.copy(), + "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), + "logits": [None]*len(frames), + "select_frame_number": 0, + "last_frame_numer": 0, + "fps": None + } + image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size) + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(image_state["origin_images"][0]) + return image_state, image_info, image_state["origin_images"][0], \ + gr.update(visible=True, maximum=10, value=10), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True),\ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=True), \ + gr.update(visible=True) + +# extract frames from upload video +def get_frames_from_video(video_input, video_state): + """ + Args: + video_path:str + timestamp:float64 + Return + [[0:nearest_frame], [nearest_frame:], nearest_frame] + """ + + while model == None: + time.sleep(1) + + video_path = video_input + frames = [] + user_name = time.time() + + # extract Audio + # try: + # audio_path = video_input.replace(".mp4", "_audio.wav") + # ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True) + # except Exception as e: + # print(f"Audio extraction error: {str(e)}") + # audio_path = "" # Set to "" if extraction fails + # print(f'audio_path: {audio_path}') + audio_path = "" + # extract frames + try: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + current_memory_usage = psutil.virtual_memory().percent + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + if current_memory_usage > 90: + break + else: + break + except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: + print("read_frame_source:{} error. {}\n".format(video_path, str(e))) + image_size = (frames[0].shape[0],frames[0].shape[1]) + + # resize if resolution too big + if image_size[0]>=1280 and image_size[0]>=1280: + scale = 1080 / min(image_size) + new_w = int(image_size[1] * scale) + new_h = int(image_size[0] * scale) + # update frames + frames = [cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA) for f in frames] + # update image_size + image_size = (frames[0].shape[0],frames[0].shape[1]) + + # initialize video_state + video_state = { + "user_name": user_name, + "video_name": os.path.split(video_path)[-1], + "origin_images": frames, + "painted_images": frames.copy(), + "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), + "logits": [None]*len(frames), + "select_frame_number": 0, + "last_frame_number": 0, + "fps": fps, + "audio": audio_path + } + video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + return video_state, video_info, video_state["origin_images"][0], \ + gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ + gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True),\ + gr.update(visible=True), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=True), \ + gr.update(visible=True) + +# get the select frame from gradio slider +def select_video_template(image_selection_slider, video_state, interactive_state): + + image_selection_slider -= 1 + video_state["select_frame_number"] = image_selection_slider + + # once select a new template frame, set the image in sam + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + + return video_state["painted_images"][image_selection_slider], video_state, interactive_state + +def select_image_template(image_selection_slider, video_state, interactive_state): + + image_selection_slider = 0 # fixed for image + video_state["select_frame_number"] = image_selection_slider + + # once select a new template frame, set the image in sam + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + + return video_state["painted_images"][image_selection_slider], video_state, interactive_state + +# set the tracking end frame +def get_end_number(track_pause_number_slider, video_state, interactive_state): + interactive_state["track_end_number"] = track_pause_number_slider + + return video_state["painted_images"][track_pause_number_slider],interactive_state + +# use sam to get the mask +def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData ): # + """ + Args: + template_frame: PIL.Image + point_prompt: flag for positive or negative button click + click_state: [[points], [labels]] + """ + if point_prompt == "Positive": + coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) + interactive_state["positive_click_times"] += 1 + else: + coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) + interactive_state["negative_click_times"] += 1 + + # prompt for sam model + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) + prompt = get_prompt(click_state=click_state, click_input=coordinate) + + mask, logit, painted_image = model.first_frame_click( + image=video_state["origin_images"][video_state["select_frame_number"]], + points=np.array(prompt["input_point"]), + labels=np.array(prompt["input_label"]), + multimask=prompt["multimask_output"], + ) + video_state["masks"][video_state["select_frame_number"]] = mask + video_state["logits"][video_state["select_frame_number"]] = logit + video_state["painted_images"][video_state["select_frame_number"]] = painted_image + + return painted_image, video_state, interactive_state + +def add_multi_mask(video_state, interactive_state, mask_dropdown): + mask = video_state["masks"][video_state["select_frame_number"]] + interactive_state["multi_mask"]["masks"].append(mask) + interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + select_frame = show_mask(video_state, interactive_state, mask_dropdown) + + return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]] + +def clear_click(video_state, click_state): + click_state = [[],[]] + template_frame = video_state["origin_images"][video_state["select_frame_number"]] + return template_frame, click_state + +def remove_multi_mask(interactive_state, mask_dropdown): + interactive_state["multi_mask"]["mask_names"]= [] + interactive_state["multi_mask"]["masks"] = [] + + return interactive_state, gr.update(choices=[],value=[]) + +def show_mask(video_state, interactive_state, mask_dropdown): + mask_dropdown.sort() + if video_state["origin_images"]: + select_frame = video_state["origin_images"][video_state["select_frame_number"]] + for i in range(len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + mask = interactive_state["multi_mask"]["masks"][mask_number] + select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) + + return select_frame + + +def save_video(frames, output_path, fps): + + writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8) + for frame in frames: + writer.append_data(frame) + writer.close() + + return output_path + +# video matting +def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): + matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) + # if interactive_state["track_end_number"]: + # following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] + # else: + end_slider = max(video_state["select_frame_number"] +1, end_slider) + following_frames = video_state["origin_images"][video_state["select_frame_number"]: end_slider] + + if interactive_state["multi_mask"]["masks"]: + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) + for i in range(1,len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) + video_state["masks"][video_state["select_frame_number"]]= template_mask + else: + template_mask = video_state["masks"][video_state["select_frame_number"]] + fps = video_state["fps"] + + audio_path = video_state["audio"] + + # operation error + if len(np.unique(template_mask))==1: + template_mask[0][0]=1 + foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size) + output_frames = [] + foreground_mat = matting_type == "Foreground" + for frame_origin, frame_alpha in zip(following_frames, alpha): + if foreground_mat: + frame_alpha[frame_alpha > 127] = 255 + frame_alpha[frame_alpha <= 127] = 0 + else: + frame_temp = frame_alpha.copy() + frame_alpha[frame_temp > 127] = 0 + frame_alpha[frame_temp <= 127] = 255 + + output_frame = np.bitwise_and(frame_origin, 255-frame_alpha) + frame_grey = frame_alpha.copy() + frame_grey[frame_alpha == 255] = 127 + output_frame += frame_grey + output_frames.append(output_frame) + foreground = output_frames + + if not os.path.exists("mask_outputs"): + os.makedirs("mask_outputs") + + file_name= video_state["video_name"] + file_name = ".".join(file_name.split(".")[:-1]) + foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps) + # foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video + alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps) + # alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video + + return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) + + +def show_outputs(): + return gr.update(visible=True), gr.update(visible=True) + +def add_audio_to_video(video_path, audio_path, output_path): + try: + video_input = ffmpeg.input(video_path) + audio_input = ffmpeg.input(audio_path) + + _ = ( + ffmpeg + .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac") + .run(overwrite_output=True, capture_stdout=True, capture_stderr=True) + ) + return output_path + except ffmpeg.Error as e: + print(f"FFmpeg error:\n{e.stderr.decode()}") + return None + + +def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""): + """ + Generates a video from a list of frames. + + Args: + frames (list of numpy arrays): The frames to include in the video. + output_path (str): The path to save the generated video. + fps (int, optional): The frame rate of the output video. Defaults to 30. + """ + frames = torch.from_numpy(np.asarray(frames)) + _, h, w, _ = frames.shape + if gray2rgb: + frames = np.repeat(frames, 3, axis=3) + + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + video_temp_path = output_path.replace(".mp4", "_temp.mp4") + + # resize back to ensure input resolution + imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7, + codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"]) + + # add audio to video if audio path exists + if audio_path != "" and os.path.exists(audio_path): + output_path = add_audio_to_video(video_temp_path, audio_path, output_path) + os.remove(video_temp_path) + return output_path + else: + return video_temp_path + +# reset all states for a new input +def restart(): + return { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + }, { + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": arg_mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + }, [[],[]], None, None, \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) + +def load_unload_models(selected): + global model + global matanyone_model + if selected: + # args, defined in track_anything.py + sam_checkpoint_url_dict = { + 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + } + # os.path.join('.') + + from mmgp import offload + + # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".") + sam_checkpoint = None + + transfer_stream = torch.cuda.Stream() + with torch.cuda.stream(transfer_stream): + # initialize sams + model = MaskGenerator(sam_checkpoint, "cuda") + from .matanyone.model.matanyone import MatAnyone + matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } + # offload.profile(pipe) + matanyone_model = matanyone_model.to(arg_device).eval() + matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) + else: + import gc + model = None + matanyone_model = None + gc.collect() + torch.cuda.empty_cache() + + +def get_vmc_event_handler(): + return load_unload_models + +def export_to_vace_video_input(foreground_video_output): + gr.Info("Masked Video Input transferred to Vace For Inpainting") + return "V#" + str(time.time()), foreground_video_output + +def export_to_vace_video_mask(foreground_video_output, alpha_video_output): + gr.Info("Masked Video Input and Full Mask transferred to Vace For Stronger Inpainting") + return "MV#" + str(time.time()), foreground_video_output, alpha_video_output + +def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger): + # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) + + media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" + + # download assets + + gr.Markdown("Mast Edition is provided by MatAnyone") + + with gr.Column( visible=True): + with gr.Row(): + with gr.Accordion("Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"): + with gr.Row(): + with gr.Column(): + gr.Markdown("### Case 1: Single Target") + gr.Video(value="preprocessing/matanyone/tutorial_single_target.mp4", elem_classes="video") + + with gr.Column(): + gr.Markdown("### Case 2: Multiple Targets") + gr.Video(value="preprocessing/matanyone/tutorial_multi_targets.mp4", elem_classes="video") + + + click_state = gr.State([[],[]]) + + interactive_state = gr.State({ + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": arg_mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + } + ) + + video_state = gr.State( + { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 16, + "audio": "", + } + ) + + with gr.Column( visible=True): + with gr.Row(): + with gr.Accordion('MatAnyone Settings (click to expand)', open=False): + with gr.Row(): + erode_kernel_size = gr.Slider(label='Erode Kernel Size', + minimum=0, + maximum=30, + step=1, + value=10, + info="Erosion on the added mask", + interactive=True) + dilate_kernel_size = gr.Slider(label='Dilate Kernel Size', + minimum=0, + maximum=30, + step=1, + value=10, + info="Dilation on the added mask", + interactive=True) + + with gr.Row(): + image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False) + end_selection_slider = gr.Slider(minimum=1, maximum=300, step=1, value=81, label="Last Frame to Process", info="Last Frame to Process", visible=False) + + track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="End frame", visible=False) + with gr.Row(): + point_prompt = gr.Radio( + choices=["Positive", "Negative"], + value="Positive", + label="Point Prompt", + info="Click to add positive or negative point for target mask", + interactive=True, + visible=False, + min_width=100, + scale=1) + matting_type = gr.Radio( + choices=["Foreground", "Background"], + value="Foreground", + label="Matting Type", + info="Type of Video Matting to Generate", + interactive=True, + visible=False, + min_width=100, + scale=1) + mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False, scale=2) + + gr.Markdown("---") + + with gr.Column(): + # input video + with gr.Row(equal_height=True): + with gr.Column(scale=2): + gr.Markdown("## Step1: Upload video") + with gr.Column(scale=2): + step2_title = gr.Markdown("## Step2: Add masks (Several clicks then **`Add Mask`** one by one)", visible=False) + with gr.Row(equal_height=True): + with gr.Column(scale=2): + video_input = gr.Video(label="Input Video", elem_classes="video") + extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button") + with gr.Column(scale=2): + video_info = gr.Textbox(label="Video Info", visible=False) + template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image") + with gr.Row(): + clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100) + add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100) + remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use + matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100) + with gr.Row(): + gr.Markdown("") + + # output video + with gr.Row(equal_height=True) as output_row: + with gr.Column(scale=2): + foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video") + foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button") + export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False) + with gr.Column(scale=2): + alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video") + alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button") + export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting", visible= False) + + export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input]) + export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask]) + # first step: get the video information + extract_frames_button.click( + fn=get_frames_from_video, + inputs=[ + video_input, video_state + ], + outputs=[video_state, video_info, template_frame, + image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame, + foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title] + ) + + # second step: select images from slider + image_selection_slider.release(fn=select_video_template, + inputs=[image_selection_slider, video_state, interactive_state], + outputs=[template_frame, video_state, interactive_state], api_name="select_image") + track_pause_number_slider.release(fn=get_end_number, + inputs=[track_pause_number_slider, video_state, interactive_state], + outputs=[template_frame, interactive_state], api_name="end_image") + + # click select image to get mask using sam + template_frame.select( + fn=sam_refine, + inputs=[video_state, point_prompt, click_state, interactive_state], + outputs=[template_frame, video_state, interactive_state] + ) + + # add different mask + add_mask_button.click( + fn=add_multi_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, template_frame, click_state] + ) + + remove_mask_button.click( + fn=remove_multi_mask, + inputs=[interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown] + ) + + # video matting + matting_button.click( + fn=show_outputs, + inputs=[], + outputs=[foreground_video_output, alpha_video_output]).then( + fn=video_matting, + inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size], + outputs=[foreground_video_output, alpha_video_output,foreground_video_output, alpha_video_output, export_to_vace_video_input_btn, export_to_vace_video_mask_btn] + ) + + # click to get mask + mask_dropdown.change( + fn=show_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[template_frame] + ) + + # clear input + video_input.change( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + foreground_video_output, alpha_video_output, + template_frame, + image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click, + add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title + ], + queue=False, + show_progress=False) + + video_input.clear( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + foreground_video_output, alpha_video_output, + template_frame, + image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click, + add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title + ], + queue=False, + show_progress=False) + + # points clear + clear_button_click.click( + fn = clear_click, + inputs = [video_state, click_state,], + outputs = [template_frame,click_state], + ) diff --git a/preprocessing/matanyone/matanyone/config/__init__.py b/preprocessing/matanyone/matanyone/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml b/preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml new file mode 100644 index 0000000..0c4d34f --- /dev/null +++ b/preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml @@ -0,0 +1,47 @@ +defaults: + - _self_ + - model: base + - override hydra/job_logging: custom-no-rank.yaml + +hydra: + run: + dir: ../output/${exp_id}/${dataset} + output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra + +amp: False +weights: pretrained_models/matanyone.pth # default (can be modified from outside) +output_dir: null # defaults to run_dir; specify this to override +flip_aug: False + + +# maximum shortest side of the input; -1 means no resizing +# With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader) +# this parameter is added for the sole purpose for the GUI in the current codebase +# InferenceCore will downsize the input and restore the output to the original size if needed +# if you are using this code for some other project, you can also utilize this parameter +max_internal_size: -1 + +# these parameters, when set, override the dataset's default; useful for debugging +save_all: True +use_all_masks: False +use_long_term: False +mem_every: 5 + +# only relevant when long_term is not enabled +max_mem_frames: 5 + +# only relevant when long_term is enabled +long_term: + count_usage: True + max_mem_frames: 10 + min_mem_frames: 5 + num_prototypes: 128 + max_num_tokens: 10000 + buffer_tokens: 2000 + +top_k: 30 +stagger_updates: 5 +chunk_size: -1 # number of objects to process in parallel; -1 means unlimited +save_scores: False +save_aux: False +visualize: False diff --git a/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom-no-rank.yaml b/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom-no-rank.yaml new file mode 100644 index 0000000..0173c68 --- /dev/null +++ b/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom-no-rank.yaml @@ -0,0 +1,22 @@ +# python logging configuration for tasks +version: 1 +formatters: + simple: + format: '[%(asctime)s][%(levelname)s] - %(message)s' + datefmt: '%Y-%m-%d %H:%M:%S' +handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + # absolute file path + filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log + mode: w +root: + level: INFO + handlers: [console, file] + +disable_existing_loggers: false \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom.yaml b/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom.yaml new file mode 100644 index 0000000..16d4969 --- /dev/null +++ b/preprocessing/matanyone/matanyone/config/hydra/job_logging/custom.yaml @@ -0,0 +1,22 @@ +# python logging configuration for tasks +version: 1 +formatters: + simple: + format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s' + datefmt: '%Y-%m-%d %H:%M:%S' +handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + # absolute file path + filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log + mode: w +root: + level: INFO + handlers: [console, file] + +disable_existing_loggers: false \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/config/model/base.yaml b/preprocessing/matanyone/matanyone/config/model/base.yaml new file mode 100644 index 0000000..3d64dcc --- /dev/null +++ b/preprocessing/matanyone/matanyone/config/model/base.yaml @@ -0,0 +1,58 @@ +pixel_mean: [0.485, 0.456, 0.406] +pixel_std: [0.229, 0.224, 0.225] + +pixel_dim: 256 +key_dim: 64 +value_dim: 256 +sensory_dim: 256 +embed_dim: 256 + +pixel_encoder: + type: resnet50 + ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1 + +mask_encoder: + type: resnet18 + final_dim: 256 + +pixel_pe_scale: 32 +pixel_pe_temperature: 128 + +object_transformer: + embed_dim: ${model.embed_dim} + ff_dim: 2048 + num_heads: 8 + num_blocks: 3 + num_queries: 16 + read_from_pixel: + input_norm: False + input_add_pe: False + add_pe_to_qkv: [True, True, False] + read_from_past: + add_pe_to_qkv: [True, True, False] + read_from_memory: + add_pe_to_qkv: [True, True, False] + read_from_query: + add_pe_to_qkv: [True, True, False] + output_norm: False + query_self_attention: + add_pe_to_qkv: [True, True, False] + pixel_self_attention: + add_pe_to_qkv: [True, True, False] + +object_summarizer: + embed_dim: ${model.object_transformer.embed_dim} + num_summaries: ${model.object_transformer.num_queries} + add_pe: True + +aux_loss: + sensory: + enabled: True + weight: 0.01 + query: + enabled: True + weight: 0.01 + +mask_decoder: + # first value must equal embed_dim + up_dims: [256, 128, 128, 64, 16] diff --git a/preprocessing/matanyone/matanyone/inference/__init__.py b/preprocessing/matanyone/matanyone/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/inference/image_feature_store.py b/preprocessing/matanyone/matanyone/inference/image_feature_store.py new file mode 100644 index 0000000..7195b05 --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/image_feature_store.py @@ -0,0 +1,56 @@ +import warnings +from typing import Iterable +import torch +from ..model.matanyone import MatAnyone + + +class ImageFeatureStore: + """ + A cache for image features. + These features might be reused at different parts of the inference pipeline. + This class provide an interface for reusing these features. + It is the user's responsibility to delete redundant features. + + Feature of a frame should be associated with a unique index -- typically the frame id. + """ + def __init__(self, network: MatAnyone, no_warning: bool = False): + self.network = network + self._store = {} + self.no_warning = no_warning + + def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None: + ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) + + def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + seq_length = images.shape[0] + ms_features, pix_feat = self.network.encode_image(images, seq_length) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + for index in range(seq_length): + self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0)) + + def get_features(self, index: int, + image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): + if index not in self._store: + self._encode_feature(index, image, last_feats) + + return self._store[index][:2] + + def get_key(self, index: int, + image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): + if index not in self._store: + self._encode_feature(index, image, last_feats) + + return self._store[index][2:] + + def delete(self, index: int) -> None: + if index in self._store: + del self._store[index] + + def __len__(self): + return len(self._store) + + def __del__(self): + if len(self._store) > 0 and not self.no_warning: + warnings.warn(f'Leaking {self._store.keys()} in the image feature store') diff --git a/preprocessing/matanyone/matanyone/inference/inference_core.py b/preprocessing/matanyone/matanyone/inference/inference_core.py new file mode 100644 index 0000000..12a6365 --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/inference_core.py @@ -0,0 +1,406 @@ +from typing import List, Optional, Iterable +import logging +from omegaconf import DictConfig + +import numpy as np +import torch +import torch.nn.functional as F + +from .memory_manager import MemoryManager +from .object_manager import ObjectManager +from .image_feature_store import ImageFeatureStore +from ..model.matanyone import MatAnyone +from ...utils.tensor_utils import pad_divide_by, unpad, aggregate + +log = logging.getLogger() + + +class InferenceCore: + + def __init__(self, + network: MatAnyone, + cfg: DictConfig, + *, + image_feature_store: ImageFeatureStore = None): + self.network = network + self.cfg = cfg + self.mem_every = cfg.mem_every + stagger_updates = cfg.stagger_updates + self.chunk_size = cfg.chunk_size + self.save_aux = cfg.save_aux + self.max_internal_size = cfg.max_internal_size + self.flip_aug = cfg.flip_aug + + self.curr_ti = -1 + self.last_mem_ti = 0 + # at which time indices should we update the sensory memory + if stagger_updates >= self.mem_every: + self.stagger_ti = set(range(1, self.mem_every + 1)) + else: + self.stagger_ti = set( + np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int)) + self.object_manager = ObjectManager() + self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager) + + if image_feature_store is None: + self.image_feature_store = ImageFeatureStore(self.network) + else: + self.image_feature_store = image_feature_store + + self.last_mask = None + self.last_pix_feat = None + self.last_msk_value = None + + def clear_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager) + + def clear_non_permanent_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_sensory_memory() + + def update_config(self, cfg): + self.mem_every = cfg['mem_every'] + self.memory.update_config(cfg) + + def clear_temp_mem(self): + self.memory.clear_work_mem() + # self.object_manager = ObjectManager() + self.memory.clear_obj_mem() + # self.memory.clear_sensory_memory() + + def _add_memory(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + prob: torch.Tensor, + key: torch.Tensor, + shrinkage: torch.Tensor, + selection: torch.Tensor, + *, + is_deep_update: bool = True, + force_permanent: bool = False) -> None: + """ + Memorize the given segmentation in all memory stores. + + The batch dimension is 1 if flip augmentation is not used. + image: RGB image, (1/2)*3*H*W + pix_feat: from the key encoder, (1/2)*_*H*W + prob: (1/2)*num_objects*H*W, in [0, 1] + key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W + selection can be None if not using long-term memory + is_deep_update: whether to use deep update (e.g. with the mask encoder) + force_permanent: whether to force the memory to be permanent + """ + if prob.shape[1] == 0: + # nothing to add + log.warn('Trying to add an empty object mask to memory!') + return + + if force_permanent: + as_permanent = 'all' + else: + as_permanent = 'first' + + self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids) + msk_value, sensory, obj_value, _ = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + prob, + deep_update=is_deep_update, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.memory.add_memory(key, + shrinkage, + msk_value, + obj_value, + self.object_manager.all_obj_ids, + selection=selection, + as_permanent=as_permanent) + self.last_mem_ti = self.curr_ti + if is_deep_update: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + self.last_msk_value = msk_value + + def _segment(self, + key: torch.Tensor, + selection: torch.Tensor, + pix_feat: torch.Tensor, + ms_features: Iterable[torch.Tensor], + update_sensory: bool = True) -> torch.Tensor: + """ + Produce a segmentation using the given features and the memory + + The batch dimension is 1 if flip augmentation is not used. + key/selection: for anisotropic l2: (1/2) * _ * H * W + pix_feat: from the key encoder, (1/2) * _ * H * W + ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W + with strides 16, 8, and 4 respectively + update_sensory: whether to update the sensory memory + + Returns: (num_objects+1)*H*W normalized probability; the first channel is the background + """ + bs = key.shape[0] + if self.flip_aug: + assert bs == 2 + else: + assert bs == 1 + + if not self.memory.engaged: + log.warn('Trying to segment without any memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + + uncert_output = None + + if self.curr_ti == 0: # ONLY for the first frame for prediction + memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output) + else: + memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti, + last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask) + memory_readout = self.object_manager.realize_dict(memory_readout) + + sensory, _, pred_prob_with_bg = self.network.segment(ms_features, + memory_readout, + self.memory.get_sensory( + self.object_manager.all_obj_ids), + chunk_size=self.chunk_size, + update_sensory=update_sensory) + # remove batch dim + if self.flip_aug: + # average predictions of the non-flipped and flipped version + pred_prob_with_bg = (pred_prob_with_bg[0] + + torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2 + else: + pred_prob_with_bg = pred_prob_with_bg[0] + if update_sensory: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + return pred_prob_with_bg + + def pred_all_flow(self, images): + self.total_len = images.shape[0] + images, self.pad = pad_divide_by(images, 16) + images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w) + + self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images) + + def encode_all_images(self, images): + images, self.pad = pad_divide_by(images, 16) + self.image_feature_store.get_all_features(images) # t c h w + return images + + def step(self, + image: torch.Tensor, + mask: Optional[torch.Tensor] = None, + objects: Optional[List[int]] = None, + *, + idx_mask: bool = False, + end: bool = False, + delete_buffer: bool = True, + force_permanent: bool = False, + matting: bool = True, + first_frame_pred: bool = False) -> torch.Tensor: + """ + Take a step with a new incoming image. + If there is an incoming mask with new objects, we will memorize them. + If there is no incoming mask, we will segment the image using the memory. + In both cases, we will update the memory and return a segmentation. + + image: 3*H*W + mask: H*W (if idx mask) or len(objects)*H*W or None + objects: list of object ids that are valid in the mask Tensor. + The ids themselves do not need to be consecutive/in order, but they need to be + in the same position in the list as the corresponding mask + in the tensor in non-idx-mask mode. + objects is ignored if the mask is None. + If idx_mask is False and objects is None, we sequentially infer the object ids. + idx_mask: if True, mask is expected to contain an object id at every pixel. + If False, mask should have multiple channels with each channel representing one object. + end: if we are at the end of the sequence, we do not need to update memory + if unsure just set it to False + delete_buffer: whether to delete the image feature buffer after this step + force_permanent: the memory recorded this frame will be added to the permanent memory + """ + if objects is None and mask is not None: + assert not idx_mask + objects = list(range(1, mask.shape[0] + 1)) + + # resize input if needed -- currently only used for the GUI + resize_needed = False + if self.max_internal_size > 0: + h, w = image.shape[-2:] + min_side = min(h, w) + if min_side > self.max_internal_size: + resize_needed = True + new_h = int(h / min_side * self.max_internal_size) + new_w = int(w / min_side * self.max_internal_size) + image = F.interpolate(image.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + if mask is not None: + if idx_mask: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), + size=(new_h, new_w), + mode='nearest-exact', + align_corners=False)[0, 0].round().long() + else: + mask = F.interpolate(mask.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + + self.curr_ti += 1 + + image, self.pad = pad_divide_by(image, 16) # DONE alreay for 3DCNN!! + image = image.unsqueeze(0) # add the batch dimension + if self.flip_aug: + image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0) + + # whether to update the working memory + is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or + (mask is not None)) and (not end) + # segment when there is no input mask or when the input mask is incomplete + need_segment = (mask is None) or (self.object_manager.num_obj > 0 + and not self.object_manager.has_all(objects)) + update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end) + + # reinit if it is the first frame for prediction + if first_frame_pred: + self.curr_ti = 0 + self.last_mem_ti = 0 + is_mem_frame = True + need_segment = True + update_sensory = True + + # encoding the image + ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image) + + # segmentation from memory if needed + if need_segment: + pred_prob_with_bg = self._segment(key, + selection, + pix_feat, + ms_feat, + update_sensory=update_sensory) + + # use the input mask if provided + if mask is not None: + # inform the manager of the new objects, and get a list of temporary id + # temporary ids -- indicates the position of objects in the tensor + # (starts with 1 due to the background channel) + corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects) + + mask, _ = pad_divide_by(mask, 16) + if need_segment: + # merge predicted mask with the incomplete input mask + pred_prob_no_bg = pred_prob_with_bg[1:] + # use the mutual exclusivity of segmentation + if idx_mask: + pred_prob_no_bg[:, mask > 0] = 0 + else: + pred_prob_no_bg[:, mask.max(0) > 0.5] = 0 + + new_masks = [] + for mask_id, tmp_id in enumerate(corresponding_tmp_ids): + if idx_mask: + this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg) + else: + this_mask = mask[tmp_id] + if tmp_id > pred_prob_no_bg.shape[0]: + new_masks.append(this_mask.unsqueeze(0)) + else: + # +1 for padding the background channel + pred_prob_no_bg[tmp_id - 1] = this_mask + # new_masks are always in the order of tmp_id + mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0) + elif idx_mask: + # simply convert cls to one-hot representation + if len(objects) == 0: + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + log.warn('Trying to insert an empty mask as memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + mask = torch.stack( + [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)], + dim=0) + if matting: + mask = mask.unsqueeze(0).float() / 255. + pred_prob_with_bg = torch.cat([1-mask, mask], 0) + else: + pred_prob_with_bg = aggregate(mask, dim=0) + pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0) + + self.last_mask = pred_prob_with_bg[1:].unsqueeze(0) + if self.flip_aug: + self.last_mask = torch.cat( + [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0) + self.last_pix_feat = pix_feat + + # save as memory if needed + if is_mem_frame or force_permanent: + # clear the memory for given mask and add the first predicted mask + if first_frame_pred: + self.clear_temp_mem() + self._add_memory(image, + pix_feat, + self.last_mask, + key, + shrinkage, + selection, + force_permanent=force_permanent, + is_deep_update=True) + else: # compute self.last_msk_value for non-memory frame + msk_value, _, _, _ = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + self.last_mask, + deep_update=False, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.last_msk_value = msk_value + + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + + output_prob = unpad(pred_prob_with_bg, self.pad) + if resize_needed: + # restore output to the original size + output_prob = F.interpolate(output_prob.unsqueeze(0), + size=(h, w), + mode='bilinear', + align_corners=False)[0] + + return output_prob + + def delete_objects(self, objects: List[int]) -> None: + """ + Delete the given objects from the memory. + """ + self.object_manager.delete_objects(objects) + self.memory.purge_except(self.object_manager.all_obj_ids) + + def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor: + if matting: + new_mask = output_prob[1:].squeeze(0) + else: + mask = torch.argmax(output_prob, dim=0) + + # index in tensor != object id -- remap the ids here + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + + return new_mask diff --git a/preprocessing/matanyone/matanyone/inference/kv_memory_store.py b/preprocessing/matanyone/matanyone/inference/kv_memory_store.py new file mode 100644 index 0000000..e50b794 --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/kv_memory_store.py @@ -0,0 +1,348 @@ +from typing import Dict, List, Optional, Literal +from collections import defaultdict +import torch + + +def _add_last_dim(dictionary, key, new_value, prepend=False): + # append/prepend a new value to the last dimension of a tensor in a dictionary + # if the key does not exist, put the new value in + # append by default + if key in dictionary: + dictionary[key] = torch.cat([dictionary[key], new_value], -1) + else: + dictionary[key] = new_value + + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + def __init__(self, save_selection: bool = False, save_usage: bool = False): + """ + We store keys and values of objects that first appear in the same frame in a bucket. + Each bucket contains a set of object ids. + Each bucket is associated with a single key tensor + and a dictionary of value tensors indexed by object id. + + The keys and values are stored as the concatenation of a permanent part and a temporary part. + """ + self.save_selection = save_selection + self.save_usage = save_usage + + self.global_bucket_id = 0 # does not reduce even if buckets are removed + self.buckets: Dict[int, List[int]] = {} # indexed by bucket id + self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id + self.v: Dict[int, torch.Tensor] = {} # indexed by object id + + # indexed by bucket id; the end point of permanent memory + self.perm_end_pt: Dict[int, int] = defaultdict(int) + + # shrinkage and selection are just like the keys + self.s = {} + if self.save_selection: + self.e = {} # does not contain the permanent memory part + + # usage + if self.save_usage: + self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part + self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part + + def add(self, + key: torch.Tensor, + values: Dict[int, torch.Tensor], + shrinkage: torch.Tensor, + selection: torch.Tensor, + supposed_bucket_id: int = -1, + as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: + """ + key: (1/2)*C*N + values: dict of values ((1/2)*C*N), object ids are used as keys + shrinkage: (1/2)*1*N + selection: (1/2)*C*N + + supposed_bucket_id: used to sync the bucket id between working and long-term memory + if provided, the input should all be in a single bucket indexed by this id + as_permanent: whether to store the input as permanent memory + 'no': don't + 'first': only store it as permanent memory if the bucket is empty + 'all': always store it as permanent memory + """ + bs = key.shape[0] + ne = key.shape[-1] + assert len(key.shape) == 3 + assert len(shrinkage.shape) == 3 + assert not self.save_selection or len(selection.shape) == 3 + assert as_permanent in ['no', 'first', 'all'] + + # add the value and create new buckets if necessary + if supposed_bucket_id >= 0: + enabled_buckets = [supposed_bucket_id] + bucket_exist = supposed_bucket_id in self.buckets + for obj, value in values.items(): + if bucket_exist: + assert obj in self.v + assert obj in self.buckets[supposed_bucket_id] + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + else: + assert obj not in self.v + self.v[obj] = value + self.buckets[supposed_bucket_id] = list(values.keys()) + else: + new_bucket_id = None + enabled_buckets = set() + for obj, value in values.items(): + assert len(value.shape) == 3 + if obj in self.v: + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + bucket_used = [ + bucket_id for bucket_id, object_ids in self.buckets.items() + if obj in object_ids + ] + assert len(bucket_used) == 1 # each object should only be in one bucket + enabled_buckets.add(bucket_used[0]) + else: + self.v[obj] = value + if new_bucket_id is None: + # create new bucket + new_bucket_id = self.global_bucket_id + self.global_bucket_id += 1 + self.buckets[new_bucket_id] = [] + # put the new object into the corresponding bucket + self.buckets[new_bucket_id].append(obj) + enabled_buckets.add(new_bucket_id) + + # increment the permanent size if necessary + add_as_permanent = {} # indexed by bucket id + for bucket_id in enabled_buckets: + add_as_permanent[bucket_id] = False + if as_permanent == 'all': + self.perm_end_pt[bucket_id] += ne + add_as_permanent[bucket_id] = True + elif as_permanent == 'first': + if self.perm_end_pt[bucket_id] == 0: + self.perm_end_pt[bucket_id] = ne + add_as_permanent[bucket_id] = True + + # create new counters for usage if necessary + if self.save_usage and as_permanent != 'all': + new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key to every bucket + for bucket_id in self.buckets: + if bucket_id not in enabled_buckets: + # if we are not adding new values to a bucket, we should skip it + continue + + _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) + _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) + if not add_as_permanent[bucket_id]: + if self.save_selection: + _add_last_dim(self.e, bucket_id, selection) + if self.save_usage: + _add_last_dim(self.use_cnt, bucket_id, new_count) + _add_last_dim(self.life_cnt, bucket_id, new_life) + + def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: + # increase all life count by 1 + # increase use of indexed elements + if not self.save_usage: + return + + usage = usage[:, self.perm_end_pt[bucket_id]:] + if usage.shape[-1] == 0: + # if there is no temporary memory, we don't need to update + return + self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) + self.life_cnt[bucket_id] += 1 + + def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: + # keep only the temporary elements *outside* of this range (with some boundary conditions) + # the permanent elements are ignored in this computation + # i.e., concat (a[:start], a[end:]) + # bucket with size <= min_size are not modified + + assert start >= 0 + assert end <= 0 + + object_ids = self.buckets[bucket_id] + bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] + if bucket_num_elements <= min_size: + return + + if end == 0: + # negative 0 would not work as the end index! + # effectively make the second part an empty slice + end = self.k[bucket_id].shape[-1] + 1 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + k = self.k[bucket_id] + s = self.s[bucket_id] + if self.save_selection: + e = self.e[bucket_id] + if self.save_usage: + use_cnt = self.use_cnt[bucket_id] + life_cnt = self.life_cnt[bucket_id] + + self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) + self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) + if self.save_selection: + self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) + if self.save_usage: + self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) + self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], + -1) + for obj_id in object_ids: + v = self.v[obj_id] + self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) + + def remove_old_memory(self, bucket_id: int, max_len: int) -> None: + self.sieve_by_range(bucket_id, 0, -max_len, max_len) + + def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: + # for long-term memory only + object_ids = self.buckets[bucket_id] + + assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory + + # normalize with life duration + usage = self.get_usage(bucket_id) + bs = usage.shape[0] + + survivals = [] + + for bi in range(bs): + _, survived = torch.topk(usage[bi], k=max_size) + survivals.append(survived.flatten()) + assert survived.shape[-1] == survivals[0].shape[-1] + + self.k[bucket_id] = torch.stack( + [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + self.s[bucket_id] = torch.stack( + [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + if self.save_selection: + # Long-term memory does not store selection so this should not be needed + self.e[bucket_id] = torch.stack( + [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + for obj_id in object_ids: + self.v[obj_id] = torch.stack( + [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + self.use_cnt[bucket_id] = torch.stack( + [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + self.life_cnt[bucket_id] = torch.stack( + [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + + def get_usage(self, bucket_id: int) -> torch.Tensor: + # return normalized usage + if not self.save_usage: + raise RuntimeError('I did not count usage!') + else: + usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] + return usage + + def get_all_sliced( + self, bucket_id: int, start: int, end: int + ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # return k, sk, ek, value, normalized usage in order, sliced by start and end + # this only queries the temporary memory + + assert start >= 0 + assert end <= 0 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + if end == 0: + # negative 0 would not work as the end index! + k = self.k[bucket_id][:, :, start:] + sk = self.s[bucket_id][:, :, start:] + ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None + else: + k = self.k[bucket_id][:, :, start:end] + sk = self.s[bucket_id][:, :, start:end] + ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None + + return k, sk, ek, value, usage + + def purge_except(self, obj_keep_idx: List[int]): + # purge certain objects from the memory except the one listed + obj_keep_idx = set(obj_keep_idx) + + # remove objects that are not in the keep list from the buckets + buckets_to_remove = [] + for bucket_id, object_ids in self.buckets.items(): + self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] + if len(self.buckets[bucket_id]) == 0: + buckets_to_remove.append(bucket_id) + + # remove object values that are not in the keep list + self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} + + # remove buckets that are empty + for bucket_id in buckets_to_remove: + del self.buckets[bucket_id] + del self.k[bucket_id] + del self.s[bucket_id] + if self.save_selection: + del self.e[bucket_id] + if self.save_usage: + del self.use_cnt[bucket_id] + del self.life_cnt[bucket_id] + + def clear_non_permanent_memory(self): + # clear all non-permanent memory + for bucket_id in self.buckets: + self.sieve_by_range(bucket_id, 0, 0, 0) + + def get_v_size(self, obj_id: int) -> int: + return self.v[obj_id].shape[-1] + + def size(self, bucket_id: int) -> int: + if bucket_id not in self.k: + return 0 + else: + return self.k[bucket_id].shape[-1] + + def perm_size(self, bucket_id: int) -> int: + return self.perm_end_pt[bucket_id] + + def non_perm_size(self, bucket_id: int) -> int: + return self.size(bucket_id) - self.perm_size(bucket_id) + + def engaged(self, bucket_id: Optional[int] = None) -> bool: + if bucket_id is None: + return len(self.buckets) > 0 + else: + return bucket_id in self.buckets + + @property + def num_objects(self) -> int: + return len(self.v) + + @property + def key(self) -> Dict[int, torch.Tensor]: + return self.k + + @property + def value(self) -> Dict[int, torch.Tensor]: + return self.v + + @property + def shrinkage(self) -> Dict[int, torch.Tensor]: + return self.s + + @property + def selection(self) -> Dict[int, torch.Tensor]: + return self.e + + def __contains__(self, key): + return key in self.v diff --git a/preprocessing/matanyone/matanyone/inference/memory_manager.py b/preprocessing/matanyone/matanyone/inference/memory_manager.py new file mode 100644 index 0000000..b70664c --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/memory_manager.py @@ -0,0 +1,453 @@ +import logging +from omegaconf import DictConfig +from typing import List, Dict +import torch + +from .object_manager import ObjectManager +from .kv_memory_store import KeyValueMemoryStore +from ..model.matanyone import MatAnyone +from ..model.utils.memory_utils import get_similarity, do_softmax + +log = logging.getLogger() + + +class MemoryManager: + """ + Manages all three memory stores and the transition between working/long-term memory + """ + def __init__(self, cfg: DictConfig, object_manager: ObjectManager): + self.object_manager = object_manager + self.sensory_dim = cfg.model.sensory_dim + self.top_k = cfg.top_k + self.chunk_size = cfg.chunk_size + + self.save_aux = cfg.save_aux + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + # subtract 1 because the first-frame is now counted as "permanent memory" + # and is not counted towards max_mem_frames + # but we want to keep the hyperparameters consistent as before for the same behavior + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + # dimensions will be inferred from input later + self.CK = self.CV = None + self.H = self.W = None + + # The sensory memory is stored as a dictionary indexed by object ids + # each of shape bs * C^h * H * W + self.sensory = {} + + # a dictionary indexed by object ids, each of shape bs * T * Q * C + self.obj_v = {} + + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + if self.use_long_term: + self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage) + + self.config_stale = True + self.engaged = False + + def update_config(self, cfg: DictConfig) -> None: + self.config_stale = True + self.top_k = cfg['top_k'] + + assert self.use_long_term == cfg.use_long_term, 'cannot update this' + assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this' + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor: + # affinity: bs*N*HW + # v: bs*C*N or bs*num_objects*C*N + # returns bs*C*HW or bs*num_objects*C*HW + if len(v.shape) == 3: + # single object + if uncert_mask is not None: + return v @ affinity * uncert_mask + else: + return v @ affinity + else: + bs, num_objects, C, N = v.shape + v = v.view(bs, num_objects * C, N) + out = v @ affinity + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1) + out = out * uncert_mask + return out.view(bs, num_objects, C, -1) + + def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor: + # -1 because the mask does not contain the background channel + return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]] + + def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1) + + def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1) + + def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + # All the values that the object ids refer to should have the same shape + value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1) + if self.use_long_term and obj_ids[0] in self.long_mem.value: + lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1) + value = torch.cat([lt_value, value], dim=-1) + + return value + + def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor, + last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert last_mask.shape[0] == bs + + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + # 'sensory': this_sensory, + # 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + # 'q_weights': aux_features['q_weights'] if aux_features else None, + # 'p_weights': aux_features['p_weights'] if aux_features else None, + # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor, + last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None, + last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert query_key.shape[0] == bs + assert selection.shape[0] == bs + assert last_mask.shape[0] == bs + + uncert_mask = uncert_output["mask"] if uncert_output is not None else None + + query_key = query_key.flatten(start_dim=2) # bs*C^k*HW + selection = selection.flatten(start_dim=2) # bs*C^k*HW + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + if self.use_long_term and self.long_mem.engaged(bucket_id): + # Use long-term memory + long_mem_size = self.long_mem.size(bucket_id) + memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]], + -1) + shrinkage = torch.cat( + [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1) + + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + """ + Record memory usage for working and long-term memory + """ + # ignore the index return for long-term memory + work_usage = usage[:, long_mem_size:] + self.work_mem.update_bucket_usage(bucket_id, work_usage) + + if self.count_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_bucket_usage(bucket_id, long_usage) + else: + # no long-term memory + memory_key = self.work_mem.key[bucket_id] + shrinkage = self.work_mem.shrinkage[bucket_id] + similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask) + + if self.use_long_term: + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + self.work_mem.update_bucket_usage(bucket_id, usage) + else: + affinity = do_softmax(similarity, top_k=self.top_k, inplace=True) + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + visual_readout = self._readout(affinity, + this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w) + + uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0]) + + if uncert_output is not None: + uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w + visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob) + + pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + # 'sensory': this_sensory, + # 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + # 'q_weights': aux_features['q_weights'] if aux_features else None, + # 'p_weights': aux_features['p_weights'] if aux_features else None, + # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def add_memory(self, + key: torch.Tensor, + shrinkage: torch.Tensor, + msk_value: torch.Tensor, + obj_value: torch.Tensor, + objects: List[int], + selection: torch.Tensor = None, + *, + as_permanent: bool = False) -> None: + # key: (1/2)*C*H*W + # msk_value: (1/2)*num_objects*C*H*W + # obj_value: (1/2)*num_objects*Q*C + # objects contains a list of object ids corresponding to the objects in msk_value/obj_value + bs = key.shape[0] + assert shrinkage.shape[0] == bs + assert msk_value.shape[0] == bs + assert obj_value.shape[0] == bs + + self.engaged = True + if self.H is None or self.config_stale: + self.config_stale = False + self.H, self.W = msk_value.shape[-2:] + self.HW = self.H * self.W + # convert from num. frames to num. tokens + self.max_work_tokens = self.max_mem_frames * self.HW + if self.use_long_term: + self.min_work_tokens = self.min_mem_frames * self.HW + + # key: bs*C*N + # value: bs*num_objects*C*N + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + self.CK = key.shape[1] + + msk_value = msk_value.flatten(start_dim=3) + self.CV = msk_value.shape[2] + + if selection is not None: + # not used in non-long-term mode + selection = selection.flatten(start_dim=2) + + # insert object values into object memory + for obj_id, obj in enumerate(objects): + if obj in self.obj_v: + """streaming average + each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1) + first embed_dim keeps track of the sum of embeddings + the last dim keeps the total count + averaging in done inside the object transformer + + incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1) + self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0) + """ + last_acc = self.obj_v[obj][:, :, -1] + new_acc = last_acc + obj_value[:, obj_id, :, -1] + + self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] + + obj_value[:, obj_id, :, :-1]) + self.obj_v[obj][:, :, -1] = new_acc + else: + self.obj_v[obj] = obj_value[:, obj_id] + + # convert mask value tensor into a dict for insertion + msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)} + self.work_mem.add(key, + msk_values, + shrinkage, + selection=selection, + as_permanent=as_permanent) + + for bucket_id in self.work_mem.buckets.keys(): + # long-term memory cleanup + if self.use_long_term: + # Do memory compressed if needed + if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens: + # Remove obsolete features if needed + if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens - + self.num_prototypes): + self.long_mem.remove_obsolete_features( + bucket_id, + self.max_long_tokens - self.num_prototypes - self.buffer_tokens) + + self.compress_features(bucket_id) + else: + # FIFO + self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens) + + def purge_except(self, obj_keep_idx: List[int]) -> None: + # purge certain objects from the memory except the one listed + self.work_mem.purge_except(obj_keep_idx) + if self.use_long_term and self.long_mem.engaged(): + self.long_mem.purge_except(obj_keep_idx) + self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx} + + if not self.work_mem.engaged(): + # everything is removed! + self.engaged = False + + def compress_features(self, bucket_id: int) -> None: + + # perform memory consolidation + prototype_key, prototype_value, prototype_shrinkage = self.consolidation( + *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens)) + + # remove consolidated working memory + self.work_mem.sieve_by_range(bucket_id, + 0, + -self.min_work_tokens, + min_size=self.min_work_tokens) + + # add to long-term memory + self.long_mem.add(prototype_key, + prototype_value, + prototype_shrinkage, + selection=None, + supposed_bucket_id=bucket_id) + + def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor, + candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor], + usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # find the indices with max usage + bs = candidate_key.shape[0] + assert bs in [1, 2] + + prototype_key = [] + prototype_selection = [] + for bi in range(bs): + _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True) + prototype_indices = max_usage_indices.flatten() + prototype_key.append(candidate_key[bi, :, prototype_indices]) + prototype_selection.append(candidate_selection[bi, :, prototype_indices]) + prototype_key = torch.stack(prototype_key, dim=0) + prototype_selection = torch.stack(prototype_selection, dim=0) + """ + Potentiation step + """ + similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, + prototype_selection) + affinity = do_softmax(similarity) + + # readout the values + prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()} + + # readout the shrinkage term + prototype_shrinkage = self._readout(affinity, candidate_shrinkage) + + return prototype_key, prototype_value, prototype_shrinkage + + def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]): + for obj in ids: + if obj not in self.sensory: + # also initializes the sensory memory + bs, _, h, w = sample_key.shape + self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w), + device=sample_key.device) + + def update_sensory(self, sensory: torch.Tensor, ids: List[int]): + # sensory: 1*num_objects*C*H*W + for obj_id, obj in enumerate(ids): + self.sensory[obj] = sensory[:, obj_id] + + def get_sensory(self, ids: List[int]): + # returns (1/2)*num_objects*C*H*W + return self._get_sensory_by_ids(ids) + + def clear_non_permanent_memory(self): + self.work_mem.clear_non_permanent_memory() + if self.use_long_term: + self.long_mem.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.sensory = {} + + def clear_work_mem(self): + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + + def clear_obj_mem(self): + self.obj_v = {} diff --git a/preprocessing/matanyone/matanyone/inference/object_info.py b/preprocessing/matanyone/matanyone/inference/object_info.py new file mode 100644 index 0000000..b0e0bd4 --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/object_info.py @@ -0,0 +1,24 @@ +class ObjectInfo: + """ + Store meta information for an object + """ + def __init__(self, id: int): + self.id = id + self.poke_count = 0 # count number of detections missed + + def poke(self) -> None: + self.poke_count += 1 + + def unpoke(self) -> None: + self.poke_count = 0 + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other): + if type(other) == int: + return self.id == other + return self.id == other.id + + def __repr__(self): + return f'(ID: {self.id})' diff --git a/preprocessing/matanyone/matanyone/inference/object_manager.py b/preprocessing/matanyone/matanyone/inference/object_manager.py new file mode 100644 index 0000000..34a93a2 --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/object_manager.py @@ -0,0 +1,149 @@ +from typing import Union, List, Dict + +import torch +from .object_info import ObjectInfo + + +class ObjectManager: + """ + Object IDs are immutable. The same ID always represent the same object. + Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. + Temporary IDs start from 1. + """ + + def __init__(self): + self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} + self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} + self.obj_id_to_obj: Dict[int, ObjectInfo] = {} + + self.all_historical_object_ids: List[int] = [] + + def _recompute_obj_id_to_obj_mapping(self) -> None: + self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} + + def add_new_objects( + self, objects: Union[List[ObjectInfo], ObjectInfo, + List[int]]) -> (List[int], List[int]): + if not isinstance(objects, list): + objects = [objects] + + corresponding_tmp_ids = [] + corresponding_obj_ids = [] + for obj in objects: + if isinstance(obj, int): + obj = ObjectInfo(id=obj) + + if obj in self.obj_to_tmp_id: + # old object + corresponding_tmp_ids.append(self.obj_to_tmp_id[obj]) + corresponding_obj_ids.append(obj.id) + else: + # new object + new_obj = ObjectInfo(id=obj.id) + + # new object + new_tmp_id = len(self.obj_to_tmp_id) + 1 + self.obj_to_tmp_id[new_obj] = new_tmp_id + self.tmp_id_to_obj[new_tmp_id] = new_obj + self.all_historical_object_ids.append(new_obj.id) + corresponding_tmp_ids.append(new_tmp_id) + corresponding_obj_ids.append(new_obj.id) + + self._recompute_obj_id_to_obj_mapping() + assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) + return corresponding_tmp_ids, corresponding_obj_ids + + def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None: + # delete an object or a list of objects + # re-sort the tmp ids + if isinstance(obj_ids_to_remove, int): + obj_ids_to_remove = [obj_ids_to_remove] + + new_tmp_id = 1 + total_num_id = len(self.obj_to_tmp_id) + + local_obj_to_tmp_id = {} + local_tmp_to_obj_id = {} + + for tmp_iter in range(1, total_num_id + 1): + obj = self.tmp_id_to_obj[tmp_iter] + if obj.id not in obj_ids_to_remove: + local_obj_to_tmp_id[obj] = new_tmp_id + local_tmp_to_obj_id[new_tmp_id] = obj + new_tmp_id += 1 + + self.obj_to_tmp_id = local_obj_to_tmp_id + self.tmp_id_to_obj = local_tmp_to_obj_id + self._recompute_obj_id_to_obj_mapping() + + def purge_inactive_objects(self, + max_missed_detection_count: int) -> (bool, List[int], List[int]): + # remove tmp ids of objects that are removed + obj_id_to_be_deleted = [] + tmp_id_to_be_deleted = [] + tmp_id_to_keep = [] + obj_id_to_keep = [] + + for obj in self.obj_to_tmp_id: + if obj.poke_count > max_missed_detection_count: + obj_id_to_be_deleted.append(obj.id) + tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) + else: + tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) + obj_id_to_keep.append(obj.id) + + purge_activated = len(obj_id_to_be_deleted) > 0 + if purge_activated: + self.delete_objects(obj_id_to_be_deleted) + return purge_activated, tmp_id_to_keep, obj_id_to_keep + + def tmp_to_obj_cls(self, mask) -> torch.Tensor: + # remap tmp id cls representation to the true object id representation + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + return new_mask + + def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]: + # returns the mapping in a dict format for saving it with pickle + return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} + + def realize_dict(self, obj_dict, dim=1) -> torch.Tensor: + # turns a dict indexed by obj id into a tensor, ordered by tmp IDs + output = [] + for _, obj in self.tmp_id_to_obj.items(): + if obj.id not in obj_dict: + raise NotImplementedError + output.append(obj_dict[obj.id]) + output = torch.stack(output, dim=dim) + return output + + def make_one_hot(self, cls_mask) -> torch.Tensor: + output = [] + for _, obj in self.tmp_id_to_obj.items(): + output.append(cls_mask == obj.id) + if len(output) == 0: + output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) + else: + output = torch.stack(output, dim=0) + return output + + @property + def all_obj_ids(self) -> List[int]: + return [k.id for k in self.obj_to_tmp_id] + + @property + def num_obj(self) -> int: + return len(self.obj_to_tmp_id) + + def has_all(self, objects: List[int]) -> bool: + for obj in objects: + if obj not in self.obj_to_tmp_id: + return False + return True + + def find_object_by_id(self, obj_id) -> ObjectInfo: + return self.obj_id_to_obj[obj_id] + + def find_tmp_by_id(self, obj_id) -> int: + return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]] diff --git a/preprocessing/matanyone/matanyone/inference/utils/__init__.py b/preprocessing/matanyone/matanyone/inference/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/inference/utils/args_utils.py b/preprocessing/matanyone/matanyone/inference/utils/args_utils.py new file mode 100644 index 0000000..a771cca --- /dev/null +++ b/preprocessing/matanyone/matanyone/inference/utils/args_utils.py @@ -0,0 +1,30 @@ +import logging +from omegaconf import DictConfig + +log = logging.getLogger() + + +def get_dataset_cfg(cfg: DictConfig): + dataset_name = cfg.dataset + data_cfg = cfg.datasets[dataset_name] + + potential_overrides = [ + 'image_directory', + 'mask_directory', + 'json_directory', + 'size', + 'save_all', + 'use_all_masks', + 'use_long_term', + 'mem_every', + ] + + for override in potential_overrides: + if cfg[override] is not None: + log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') + data_cfg[override] = cfg[override] + # escalte all potential overrides to the top-level config + if override in data_cfg: + cfg[override] = data_cfg[override] + + return data_cfg diff --git a/preprocessing/matanyone/matanyone/model/__init__.py b/preprocessing/matanyone/matanyone/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/model/aux_modules.py b/preprocessing/matanyone/matanyone/model/aux_modules.py new file mode 100644 index 0000000..efeb515 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/aux_modules.py @@ -0,0 +1,93 @@ +""" +For computing auxiliary outputs for auxiliary losses +""" +from typing import Dict +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from .group_modules import GConv2d +from ...utils.tensor_utils import aggregate + + +class LinearPredictor(nn.Module): + def __init__(self, x_dim: int, pix_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) + + def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # pixel_feat: B*pix_dim*H*W + # x: B*num_objects*x_dim*H*W + num_objects = x.shape[1] + x = self.projection(x) + + pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] + return logits + + +class DirectPredictor(nn.Module): + def __init__(self, x_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: B*num_objects*x_dim*H*W + logits = self.projection(x).squeeze(2) + return logits + + +class AuxComputer(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + use_sensory_aux = cfg.model.aux_loss.sensory.enabled + self.use_query_aux = cfg.model.aux_loss.query.enabled + self.use_sensory_aux = use_sensory_aux + + sensory_dim = cfg.model.sensory_dim + embed_dim = cfg.model.embed_dim + + if use_sensory_aux: + self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) + + def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + logits = aggregate(prob, dim=1) + return logits + + def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor], + selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: + sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + aux_output['attn_mask'] = aux_input['attn_mask'] + + if self.use_sensory_aux: + # B*num_objects*H*W + logits = self.sensory_aux(pix_feat, sensory) + aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) + if self.use_query_aux: + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output + + def compute_mask(self, aux_input: Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + # sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/big_modules.py b/preprocessing/matanyone/matanyone/model/big_modules.py new file mode 100644 index 0000000..4d09f53 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/big_modules.py @@ -0,0 +1,365 @@ +""" +big_modules.py - This file stores higher-level network blocks. + +x - usually denotes features that are shared between objects. +g - usually denotes features that are not shared between objects + with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). + +The trailing number of a variable usually denotes the stride +""" + +from typing import Iterable +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d +from .utils import resnet +from .modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock + +class UncertPred(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) + self.bn2 = nn.BatchNorm2d(32) + self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) + + def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): + last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area') + x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1) + x = self.conv1x1_v2(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv3x3(x) + x = self.bn2(x) + x = self.relu(x) + x = self.conv3x3_out(x) + return x + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + +class PixelEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type + # if model_cfg.pretrained_resnet is set in the model_cfg we get the value + # else default to True + is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) + if self.is_resnet: + if model_cfg.pixel_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=is_pretrained_resnet) + elif model_cfg.pixel_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=is_pretrained_resnet) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.res2 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): + f1 = x + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + f2 = x + x = self.maxpool(x) + f4 = self.res2(x) + f8 = self.layer2(f4) + f16 = self.layer3(f8) + + return f16, f8, f4, f2, f1 + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class KeyProjection(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + in_dim = model_cfg.pixel_encoder.ms_dims[0] + mid_dim = model_cfg.pixel_dim + key_dim = model_cfg.key_dim + + self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) + self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x: torch.Tensor, *, need_s: bool, + need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.pix_feat_proj(x) + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class MaskEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + pixel_dim = model_cfg.pixel_dim + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + final_dim = model_cfg.mask_encoder.final_dim + + self.single_object = single_object + extra_dim = 1 if single_object else 2 + + # if model_cfg.pretrained_resnet is set in the model_cfg we get the value + # else default to True + is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) + if model_cfg.mask_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim) + elif model_cfg.mask_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.layer1 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + + self.distributor = MainToGroupDistributor() + self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) + + self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) + + def forward(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + sensory: torch.Tensor, + masks: torch.Tensor, + others: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1) -> (torch.Tensor, torch.Tensor): + # ms_features are from the key encoder + # we only use the first one (lowest resolution), following XMem + if self.single_object: + g = masks.unsqueeze(2) + else: + g = torch.stack([masks, others], dim=2) + + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if deep_update: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_g = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + g_chunk = g + else: + g_chunk = g[:, i:i + chunk_size] + actual_chunk_size = g_chunk.shape[1] + g_chunk = g_chunk.flatten(start_dim=0, end_dim=1) + + g_chunk = self.conv1(g_chunk) + g_chunk = self.bn1(g_chunk) # 1/2, 64 + g_chunk = self.maxpool(g_chunk) # 1/4, 64 + g_chunk = self.relu(g_chunk) + + g_chunk = self.layer1(g_chunk) # 1/4 + g_chunk = self.layer2(g_chunk) # 1/8 + g_chunk = self.layer3(g_chunk) # 1/16 + + g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:]) + g_chunk = self.fuser(pix_feat, g_chunk) + all_g.append(g_chunk) + if deep_update: + if fast_path: + new_sensory = self.sensory_update(g_chunk, sensory) + else: + new_sensory[:, i:i + chunk_size] = self.sensory_update( + g_chunk, sensory[:, i:i + chunk_size]) + g = torch.cat(all_g, dim=1) + + return g, new_sensory + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class PixelFeatureFuser(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + pixel_dim = model_cfg.pixel_dim + embed_dim = model_cfg.embed_dim + self.single_object = single_object + + self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim) + if self.single_object: + self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1) + else: + self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1) + + def forward(self, + pix_feat: torch.Tensor, + pixel_memory: torch.Tensor, + sensory_memory: torch.Tensor, + last_mask: torch.Tensor, + last_others: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + batch_size, num_objects = pixel_memory.shape[:2] + + if self.single_object: + last_mask = last_mask.unsqueeze(2) + else: + last_mask = torch.stack([last_mask, last_others], dim=2) + + if chunk_size < 1: + chunk_size = num_objects + + # chunk-by-chunk inference + all_p16 = [] + for i in range(0, num_objects, chunk_size): + sensory_readout = self.sensory_compress( + torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2)) + p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout + p16 = self.fuser(pix_feat, p16) + all_p16.append(p16) + p16 = torch.cat(all_p16, dim=1) + + return p16 + + +class MaskDecoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + embed_dim = model_cfg.embed_dim + sensory_dim = model_cfg.sensory_dim + ms_image_dims = model_cfg.pixel_encoder.ms_dims + up_dims = model_cfg.mask_decoder.up_dims + + assert embed_dim == up_dims[0] + + self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim, + sensory_dim) + + self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1]) + self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1]) + self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2]) + # newly add for alpha matte + self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3]) + self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4]) + + self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + + def forward(self, + ms_image_feat: Iterable[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + chunk_size: int = -1, + update_sensory: bool = True, + seg_pass: bool = False, + last_mask=None, + sigmoid_residual=False) -> (torch.Tensor, torch.Tensor): + + batch_size, num_objects = memory_readout.shape[:2] + f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:]) + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if update_sensory: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_logits = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + p16 = memory_readout + else: + p16 = memory_readout[:, i:i + chunk_size] + actual_chunk_size = p16.shape[1] + + p8 = self.up_16_8(p16, f8) + p4 = self.up_8_4(p8, f4) + p2 = self.up_4_2(p4, f2) + p1 = self.up_2_1(p2, f1) + with torch.amp.autocast("cuda"): + if seg_pass: + if last_mask is not None: + res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + if sigmoid_residual: + res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask + logits = last_mask + res + else: + logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + else: + if last_mask is not None: + res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + if sigmoid_residual: + res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask + logits = last_mask + res + else: + logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + ## SensoryUpdater_fullscale + if update_sensory: + p1 = torch.cat( + [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2) + if fast_path: + new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory) + else: + new_sensory[:, + i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1], + sensory[:, + i:i + chunk_size]) + all_logits.append(logits) + logits = torch.cat(all_logits, dim=0) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return new_sensory, logits diff --git a/preprocessing/matanyone/matanyone/model/channel_attn.py b/preprocessing/matanyone/matanyone/model/channel_attn.py new file mode 100644 index 0000000..a2096c1 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/channel_attn.py @@ -0,0 +1,39 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CAResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, residual: bool = True): + super().__init__() + self.residual = residual + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + t = int((abs(math.log2(out_dim)) + 1) // 2) + k = t if t % 2 else t + 1 + self.pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False) + + if self.residual: + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.conv1(F.relu(x)) + x = self.conv2(F.relu(x)) + + b, c = x.shape[:2] + w = self.pool(x).view(b, 1, c) + w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 + + if self.residual: + x = x * w + self.downsample(r) + else: + x = x * w + + return x diff --git a/preprocessing/matanyone/matanyone/model/group_modules.py b/preprocessing/matanyone/matanyone/model/group_modules.py new file mode 100644 index 0000000..f143f46 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/group_modules.py @@ -0,0 +1,126 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from .channel_attn import CAResBlock + +def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, + align_corners: bool) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = F.interpolate(g.flatten(start_dim=0, end_dim=1), + scale_factor=ratio, + mode=mode, + align_corners=align_corners) + g = g.view(batch_size, num_objects, *g.shape[1:]) + return g + + +def upsample_groups(g: torch.Tensor, + ratio: float = 2, + mode: str = 'bilinear', + align_corners: bool = False) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +def downsample_groups(g: torch.Tensor, + ratio: float = 1 / 2, + mode: str = 'area', + align_corners: bool = None) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +class GConv2d(nn.Conv2d): + def forward(self, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = super().forward(g.flatten(start_dim=0, end_dim=1)) + return g.view(batch_size, num_objects, *g.shape[1:]) + + +class GroupResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g + + +class MainToGroupDistributor(nn.Module): + def __init__(self, + x_transform: Optional[nn.Module] = None, + g_transform: Optional[nn.Module] = None, + method: str = 'cat', + reverse_order: bool = False): + super().__init__() + + self.x_transform = x_transform + self.g_transform = g_transform + self.method = method + self.reverse_order = reverse_order + + def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: + num_objects = g.shape[1] + + if self.x_transform is not None: + x = self.x_transform(x) + + if self.g_transform is not None: + g = self.g_transform(g) + + if not skip_expand: + x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + if self.method == 'cat': + if self.reverse_order: + g = torch.cat([g, x], 2) + else: + g = torch.cat([x, g], 2) + elif self.method == 'add': + g = x + g + elif self.method == 'mulcat': + g = torch.cat([x * g, g], dim=2) + elif self.method == 'muladd': + g = x * g + g + else: + raise NotImplementedError + + return g + + +class GroupFeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): + super().__init__() + + x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) + g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) + + self.distributor = MainToGroupDistributor(x_transform=x_transform, + g_transform=g_transform, + method='add') + self.block1 = CAResBlock(out_dim, out_dim) + self.block2 = CAResBlock(out_dim, out_dim) + + def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + + g = g.flatten(start_dim=0, end_dim=1) + + g = self.block1(g) + g = self.block2(g) + + g = g.view(batch_size, num_objects, *g.shape[1:]) + + return g \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/matanyone.py b/preprocessing/matanyone/matanyone/model/matanyone.py new file mode 100644 index 0000000..ec32c83 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/matanyone.py @@ -0,0 +1,333 @@ +from typing import List, Dict, Iterable +import logging +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf +from huggingface_hub import PyTorchModelHubMixin + +from .big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder +from .aux_modules import AuxComputer +from .utils.memory_utils import get_affinity, readout +from .transformer.object_transformer import QueryTransformer +from .transformer.object_summarizer import ObjectSummarizer +from ...utils.tensor_utils import aggregate + +log = logging.getLogger() +class MatAnyone(nn.Module, + PyTorchModelHubMixin, + library_name="matanyone", + repo_url="https://github.com/pq-yang/MatAnyone", + coders={ + DictConfig: ( + lambda x: OmegaConf.to_container(x), + lambda data: OmegaConf.create(data), + ) + }, + ): + + def __init__(self, cfg: DictConfig, *, single_object=False): + super().__init__() + self.cfg = cfg + model_cfg = cfg.model + self.ms_dims = model_cfg.pixel_encoder.ms_dims + self.key_dim = model_cfg.key_dim + self.value_dim = model_cfg.value_dim + self.sensory_dim = model_cfg.sensory_dim + self.pixel_dim = model_cfg.pixel_dim + self.embed_dim = model_cfg.embed_dim + self.single_object = single_object + + log.info(f'Single object: {self.single_object}') + + self.pixel_encoder = PixelEncoder(model_cfg) + self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1) + self.key_proj = KeyProjection(model_cfg) + self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object) + self.mask_decoder = MaskDecoder(model_cfg) + self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object) + self.object_transformer = QueryTransformer(model_cfg) + self.object_summarizer = ObjectSummarizer(model_cfg) + self.aux_computer = AuxComputer(cfg) + self.temp_sparity = UncertPred(model_cfg) + + self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False) + + def _get_others(self, masks: torch.Tensor) -> torch.Tensor: + # for each object, return the sum of masks of all other objects + if self.single_object: + return None + + num_objects = masks.shape[1] + if num_objects >= 1: + others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1) + else: + others = torch.zeros_like(masks) + return others + + def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): + logits = self.temp_sparity(last_frame_feat=last_pix_feat, + cur_frame_feat=cur_pix_feat, + last_mask=last_mask, + mem_val_diff=mem_val_diff) + + prob = torch.sigmoid(logits) + mask = (prob > 0) + 0 + + uncert_output = {"logits": logits, + "prob": prob, + "mask": mask} + + return uncert_output + + def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore + image = (image - self.pixel_mean) / self.pixel_std + ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1 + return ms_image_feat, self.pix_feat_proj(ms_image_feat[0]) + + def encode_mask( + self, + image: torch.Tensor, + ms_features: List[torch.Tensor], + sensory: torch.Tensor, + masks: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + image = (image - self.pixel_mean) / self.pixel_std + others = self._get_others(masks) + mask_value, new_sensory = self.mask_encoder(image, + ms_features, + sensory, + masks, + others, + deep_update=deep_update, + chunk_size=chunk_size) + object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights) + return mask_value, new_sensory, object_summaries, object_logits + + def transform_key(self, + final_pix_feat: torch.Tensor, + *, + need_sk: bool = True, + need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor): + key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek) + return key, shrinkage, selection + + # Used in training only. + # This step is replaced by MemoryManager in test time + def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor, + memory_key: torch.Tensor, memory_shrinkage: torch.Tensor, + msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor, uncert_output=None, seg_pass=False, + last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + batch_size, num_objects = msk_value.shape[:2] + + uncert_mask = uncert_output["mask"] if uncert_output is not None else None + + # read using visual attention + with torch.cuda.amp.autocast(enabled=False): + affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(), + query_selection.float(), uncert_mask=uncert_mask) + + msk_value = msk_value.flatten(start_dim=1, end_dim=2).float() + + # B * (num_objects*CV) * H * W + pixel_readout = readout(affinity, msk_value, uncert_mask) + pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim, + *pixel_readout.shape[-2:]) + + uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1]) + uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w + pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob) + + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output, uncert_output + + def read_first_frame_memory(self, pixel_readout, + obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output + + def pixel_fusion(self, + pix_feat: torch.Tensor, + pixel: torch.Tensor, + sensory: torch.Tensor, + last_mask: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area') + last_others = self._get_others(last_mask) + fused = self.pixel_fuser(pix_feat, + pixel, + sensory, + last_mask, + last_others, + chunk_size=chunk_size) + return fused + + def readout_query(self, + pixel_readout, + obj_memory, + *, + selector=None, + need_weights=False, + seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + return self.object_transformer(pixel_readout, + obj_memory, + selector=selector, + need_weights=need_weights, + seg_pass=seg_pass) + + def segment(self, + ms_image_feat: List[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + selector: bool = None, + chunk_size: int = -1, + update_sensory: bool = True, + seg_pass: bool = False, + clamp_mat: bool = True, + last_mask=None, + sigmoid_residual=False, + seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + multi_scale_features is from the key encoder for skip-connection + memory_readout is from working/long-term memory + sensory is the sensory memory + last_mask is the mask from the last frame, supplementing sensory memory + selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects + during training. + """ + #### use mat head for seg data + if seg_mat: + assert seg_pass + seg_pass = False + #### + sensory, logits = self.mask_decoder(ms_image_feat, + memory_readout, + sensory, + chunk_size=chunk_size, + update_sensory=update_sensory, + seg_pass = seg_pass, + last_mask=last_mask, + sigmoid_residual=sigmoid_residual) + if seg_pass: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + + # Softmax over all objects[] + logits = aggregate(prob, dim=1) + prob = F.softmax(logits, dim=1) + else: + if clamp_mat: + logits = logits.clamp(0.0, 1.0) + logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1) + prob = logits + + return sensory, logits, prob + + def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor], + selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: + return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None: + if not self.single_object: + # Map single-object weight to multi-object weight (4->5 out channels in conv1) + for k in list(src_dict.keys()): + if k == 'mask_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif k == 'pixel_fuser.sensory_compress.weight': + if src_dict[k].shape[1] == self.sensory_dim + 1: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif self.single_object: + """ + If the model is multiple-object and we are training in single-object, + we strip the last channel of conv1. + This is not supposed to happen in standard training except when users are trying to + finetune a trained model with single object datasets. + """ + if src_dict['mask_encoder.conv1.weight'].shape[1] == 5: + log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.' + 'This is not supposed to happen in standard training.') + src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1] + src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1] + + for k in src_dict: + if k not in self.state_dict(): + log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!') + for k in self.state_dict(): + if k not in src_dict: + log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!') + + self.load_state_dict(src_dict, strict=False) + + @property + def device(self) -> torch.device: + return self.pixel_mean.device diff --git a/preprocessing/matanyone/matanyone/model/modules.py b/preprocessing/matanyone/matanyone/model/modules.py new file mode 100644 index 0000000..7350425 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/modules.py @@ -0,0 +1,149 @@ +from typing import List, Iterable +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups + + +class UpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.out_conv = ResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = F.interpolate(in_g, + scale_factor=self.scale_factor, + mode='bilinear') + g = self.out_conv(g) + g = g + skip_f + return g + +class MaskUpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = upsample_groups(in_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class DecoderFeatureProcessor(nn.Module): + def __init__(self, decoder_dims: List[int], out_dims: List[int]): + super().__init__() + self.transforms = nn.ModuleList([ + nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) + ]) + + def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: + outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] + return outputs + + +# @torch.jit.script +def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + # h: batch_size * num_objects * hidden_dim * h * w + # values: batch_size * num_objects * (hidden_dim*3) * h * w + dim = values.shape[2] // 3 + forget_gate = torch.sigmoid(values[:, :, :dim]) + update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) + new_value = torch.tanh(values[:, :, dim * 2:]) + new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value + return new_h + + +class SensoryUpdater_fullscale(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1) + self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \ + self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \ + self.g1_conv(downsample_groups(g[4], ratio=1/16)) + + with torch.amp.autocast("cuda"): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + +class SensoryUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + + with torch.amp.autocast("cuda"): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class SensoryDeepUpdater(nn.Module): + def __init__(self, f_dim: int, sensory_dim: int): + super().__init__() + self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda"): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class ResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/transformer/__init__.py b/preprocessing/matanyone/matanyone/model/transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/model/transformer/object_summarizer.py b/preprocessing/matanyone/matanyone/model/transformer/object_summarizer.py new file mode 100644 index 0000000..a2cf75a --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/transformer/object_summarizer.py @@ -0,0 +1,89 @@ +from typing import Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .positional_encoding import PositionalEncoding + + +# @torch.jit.script +def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, + logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + # value: B*num_objects*H*W*value_dim + # logits: B*num_objects*H*W*num_summaries + # masks: B*num_objects*H*W*num_summaries: 1 if allowed + weights = logits.sigmoid() * masks + # B*num_objects*num_summaries*value_dim + sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) + # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 + area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) + + # B*num_objects*num_summaries*value_dim + return sums, area + + +class ObjectSummarizer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_summarizer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_summaries = this_cfg.num_summaries + self.add_pe = this_cfg.add_pe + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + + if self.add_pe: + self.pos_enc = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature) + + self.input_proj = nn.Linear(self.value_dim, self.embed_dim) + self.feature_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.weights_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.num_summaries), + ) + + def forward(self, + masks: torch.Tensor, + value: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): + # masks: B*num_objects*(H0)*(W0) + # value: B*num_objects*value_dim*H*W + # -> B*num_objects*H*W*value_dim + h, w = value.shape[-2:] + masks = F.interpolate(masks, size=(h, w), mode='area') + masks = masks.unsqueeze(-1) + inv_masks = 1 - masks + repeated_masks = torch.cat([ + masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + ], + dim=-1) + + value = value.permute(0, 1, 3, 4, 2) + value = self.input_proj(value) + if self.add_pe: + pe = self.pos_enc(value) + value = value + pe + + with torch.amp.autocast("cuda"): + value = value.float() + feature = self.feature_pred(value) + logits = self.weights_pred(value) + sums, area = _weighted_pooling(repeated_masks, feature, logits) + + summaries = torch.cat([sums, area], dim=-1) + + if need_weights: + return summaries, logits + else: + return summaries, None \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/transformer/object_transformer.py b/preprocessing/matanyone/matanyone/model/transformer/object_transformer.py new file mode 100644 index 0000000..1aa6664 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/transformer/object_transformer.py @@ -0,0 +1,206 @@ +from typing import Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +from ..group_modules import GConv2d +from ....utils.tensor_utils import aggregate +from .positional_encoding import PositionalEncoding +from .transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN + + +class QueryTransformerBlock(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + self.ff_dim = this_cfg.ff_dim + + self.read_from_pixel = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv) + self.self_attn = SelfAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv) + self.ffn = FFN(self.embed_dim, self.ff_dim) + self.read_from_query = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv, + norm=this_cfg.read_from_query.output_norm) + self.pixel_ffn = PixelFFN(self.embed_dim) + + def forward( + self, + x: torch.Tensor, + pixel: torch.Tensor, + query_pe: torch.Tensor, + pixel_pe: torch.Tensor, + attn_mask: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + # x: (bs*num_objects)*num_queries*embed_dim + # pixel: bs*num_objects*C*H*W + # query_pe: (bs*num_objects)*num_queries*embed_dim + # pixel_pe: (bs*num_objects)*(H*W)*C + # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W) + + # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C + pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + x, q_weights = self.read_from_pixel(x, + pixel_flat, + query_pe, + pixel_pe, + attn_mask=attn_mask, + need_weights=need_weights) + x = self.self_attn(x, query_pe) + x = self.ffn(x) + + pixel_flat, p_weights = self.read_from_query(pixel_flat, + x, + pixel_pe, + query_pe, + need_weights=need_weights) + pixel = self.pixel_ffn(pixel, pixel_flat) + + if need_weights: + bs, num_objects, _, h, w = pixel.shape + q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w) + p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads, + self.num_queries, h, w) + + return x, pixel, q_weights, p_weights + + +class QueryTransformer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + + # query initialization and embedding + self.query_init = nn.Embedding(self.num_queries, self.embed_dim) + self.query_emb = nn.Embedding(self.num_queries, self.embed_dim) + + # projection from object summaries to query initialization and embedding + self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim) + self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim) + + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.spatial_pe = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature, + channel_last=False, + transpose_output=True) + + # transformer blocks + self.num_blocks = this_cfg.num_blocks + self.blocks = nn.ModuleList( + QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks)) + self.mask_pred = nn.ModuleList( + nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1)) + for _ in range(self.num_blocks + 1)) + + self.act = nn.ReLU(inplace=True) + + def forward(self, + pixel: torch.Tensor, + obj_summaries: torch.Tensor, + selector: Optional[torch.Tensor] = None, + need_weights: bool = False, + seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + # pixel: B*num_objects*embed_dim*H*W + # obj_summaries: B*num_objects*T*num_queries*embed_dim + T = obj_summaries.shape[2] + bs, num_objects, _, H, W = pixel.shape + + # normalize object values + # the last channel is the cumulative area of the object + obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries, + self.embed_dim + 1) + # sum over time + # during inference, T=1 as we already did streaming average in memory_manager + obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1) + obj_area = obj_summaries[:, :, :, -1:].sum(dim=1) + obj_values = obj_sums / (obj_area + 1e-4) + obj_init = self.summary_to_query_init(obj_values) + obj_emb = self.summary_to_query_emb(obj_values) + + # positional embeddings for object queries + query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init + query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb + + # positional embeddings for pixel features + pixel_init = self.pixel_init_proj(pixel) + pixel_emb = self.pixel_emb_proj(pixel) + pixel_pe = self.spatial_pe(pixel.flatten(0, 1)) + pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb + + pixel = pixel_init + + # run the transformer + aux_features = {'logits': []} + + # first aux output + aux_logits = self.mask_pred[0](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) + aux_features['logits'].append(aux_logits) + for i in range(self.num_blocks): + query, pixel, q_weights, p_weights = self.blocks[i](query, + pixel, + query_emb, + pixel_pe, + attn_mask, + need_weights=need_weights) + + if self.training or i <= self.num_blocks - 1 or need_weights: + aux_logits = self.mask_pred[i + 1](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) + aux_features['logits'].append(aux_logits) + + aux_features['q_weights'] = q_weights # last layer only + aux_features['p_weights'] = p_weights # last layer only + + if self.training: + # no need to save all heads + aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads, + self.num_queries, H, W)[:, :, 0] + + return pixel, aux_features + + def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor: + # logits: batch_size*num_objects*H*W + # selector: batch_size*num_objects*1*1 + # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W) + # where True means the attention is blocked + + if selector is None: + prob = logits.sigmoid() + else: + prob = logits.sigmoid() * selector + logits = aggregate(prob, dim=1) + + is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]) + foreground_mask = is_foreground.bool().flatten(start_dim=2) + inv_foreground_mask = ~foreground_mask + inv_background_mask = foreground_mask + + aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + + aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1) + + aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False + + return aux_mask \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/transformer/positional_encoding.py b/preprocessing/matanyone/matanyone/model/transformer/positional_encoding.py new file mode 100644 index 0000000..6c15bb7 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/transformer/positional_encoding.py @@ -0,0 +1,108 @@ +# Reference: +# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py +# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py + +import math + +import numpy as np +import torch +from torch import nn + + +def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding(nn.Module): + def __init__(self, + dim: int, + scale: float = math.pi * 2, + temperature: float = 10000, + normalize: bool = True, + channel_last: bool = True, + transpose_output: bool = False): + super().__init__() + dim = int(np.ceil(dim / 4) * 2) + self.dim = dim + inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.normalize = normalize + self.scale = scale + self.eps = 1e-6 + self.channel_last = channel_last + self.transpose_output = transpose_output + + self.cached_penc = None # the cache is irrespective of the number of objects + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: A 4/5d tensor of size + channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) + channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) + :return: positional encoding tensor that has the same shape as the input if the input is 4d + if the input is 5d, the output is broadcastable along the k-dimension + """ + if len(tensor.shape) != 4 and len(tensor.shape) != 5: + raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') + + if len(tensor.shape) == 5: + # take a sample from the k dimension + num_objects = tensor.shape[1] + tensor = tensor[:, 0] + else: + num_objects = None + + if self.channel_last: + batch_size, h, w, c = tensor.shape + else: + batch_size, c, h, w = tensor.shape + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + self.cached_penc = None + + pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) + pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) + if self.normalize: + pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale + pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale + + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_y = get_emb(sin_inp_y).unsqueeze(1) + emb_x = get_emb(sin_inp_x) + + emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) + emb[:, :, :self.dim] = emb_x + emb[:, :, self.dim:] = emb_y + + if not self.channel_last and self.transpose_output: + # cancelled out + pass + elif (not self.channel_last) or (self.transpose_output): + emb = emb.permute(2, 0, 1) + + self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + +if __name__ == '__main__': + pe = PositionalEncoding(8).cuda() + input = torch.ones((1, 8, 8, 8)).cuda() + output = pe(input) + # print(output) + print(output[0, :, 0, 0]) + print(output[0, :, 0, 5]) + print(output[0, 0, :, 0]) + print(output[0, 0, 0, :]) diff --git a/preprocessing/matanyone/matanyone/model/transformer/transformer_layers.py b/preprocessing/matanyone/matanyone/model/transformer/transformer_layers.py new file mode 100644 index 0000000..0b57bf2 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/transformer/transformer_layers.py @@ -0,0 +1,161 @@ +# Modified from PyTorch nn.Transformer + +from typing import List, Callable + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from ...model.channel_attn import CAResBlock + + +class SelfAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False]): + super().__init__() + self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) + self.norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + + def forward(self, + x: torch.Tensor, + pe: torch.Tensor, + attn_mask: bool = None, + key_padding_mask: bool = None) -> torch.Tensor: + x = self.norm(x) + if any(self.add_pe_to_qkv): + x_with_pe = x + pe + q = x_with_pe if self.add_pe_to_qkv[0] else x + k = x_with_pe if self.add_pe_to_qkv[1] else x + v = x_with_pe if self.add_pe_to_qkv[2] else x + else: + q = k = v = x + + r = x + x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + return r + self.dropout(x) + + +# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention +class CrossAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False], + residual: bool = True, + norm: bool = True): + super().__init__() + self.cross_attn = nn.MultiheadAttention(dim, + nhead, + dropout=dropout, + batch_first=batch_first) + if norm: + self.norm = nn.LayerNorm(dim) + else: + self.norm = nn.Identity() + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + self.residual = residual + + def forward(self, + x: torch.Tensor, + mem: torch.Tensor, + x_pe: torch.Tensor, + mem_pe: torch.Tensor, + attn_mask: bool = None, + *, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor): + x = self.norm(x) + if self.add_pe_to_qkv[0]: + q = x + x_pe + else: + q = x + + if any(self.add_pe_to_qkv[1:]): + mem_with_pe = mem + mem_pe + k = mem_with_pe if self.add_pe_to_qkv[1] else mem + v = mem_with_pe if self.add_pe_to_qkv[2] else mem + else: + k = v = mem + r = x + x, weights = self.cross_attn(q, + k, + v, + attn_mask=attn_mask, + need_weights=need_weights, + average_attn_weights=False) + + if self.residual: + return r + self.dropout(x), weights + else: + return self.dropout(x), weights + + +class FFN(nn.Module): + def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_ff) + self.linear2 = nn.Linear(dim_ff, dim_in) + self.norm = nn.LayerNorm(dim_in) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.norm(x) + x = self.linear2(self.activation(self.linear1(x))) + x = r + x + return x + + +class PixelFFN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + self.conv = CAResBlock(dim, dim) + + def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor: + # pixel: batch_size * num_objects * dim * H * W + # pixel_flat: (batch_size*num_objects) * (H*W) * dim + bs, num_objects, _, h, w = pixel.shape + pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim) + pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous() + + x = self.conv(pixel_flat) + x = x.view(bs, num_objects, self.dim, h, w) + return x + + +class OutputFFN(nn.Module): + def __init__(self, dim_in: int, dim_out: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_out) + self.linear2 = nn.Linear(dim_out, dim_out) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) diff --git a/preprocessing/matanyone/matanyone/model/utils/__init__.py b/preprocessing/matanyone/matanyone/model/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py new file mode 100644 index 0000000..e7dd5e7 --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py @@ -0,0 +1,107 @@ +import math +import torch +from typing import Optional, Union, Tuple + + +# @torch.jit.script +def get_similarity(mk: torch.Tensor, + ms: torch.Tensor, + qk: torch.Tensor, + qe: torch.Tensor, + add_batch_dim: bool = False, + uncert_mask = None) -> torch.Tensor: + # used for training/inference and memory reading/memory potentiation + # mk: B x CK x [N] - Memory keys + # ms: B x 1 x [N] - Memory shrinkage + # qk: B x CK x [HW/P] - Query keys + # qe: B x CK x [HW/P] - Query selection + # Dimensions in [] are flattened + # Return: B*N*HW + if add_batch_dim: + mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) + qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) + + CK = mk.shape[1] + + mk = mk.flatten(start_dim=2) + ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None + qk = qk.flatten(start_dim=2) + qe = qe.flatten(start_dim=2) if qe is not None else None + + # query token selection based on temporal sparsity + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2) + uncert_mask = uncert_mask.expand(-1, 64, -1) + qk = qk * uncert_mask + qe = qe * uncert_mask + + if qe is not None: + # See XMem's appendix for derivation + mk = mk.transpose(1, 2) + a_sq = (mk.pow(2) @ qe) + two_ab = 2 * (mk @ (qk * qe)) + b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) + similarity = (-a_sq + two_ab - b_sq) + else: + # similar to STCN if we don't have the selection term + a_sq = mk.pow(2).sum(1).unsqueeze(2) + two_ab = 2 * (mk.transpose(1, 2) @ qk) + similarity = (-a_sq + two_ab) + + if ms is not None: + similarity = similarity * ms / math.sqrt(CK) # B*N*HW + else: + similarity = similarity / math.sqrt(CK) # B*N*HW + + return similarity + + +def do_softmax( + similarity: torch.Tensor, + top_k: Optional[int] = None, + inplace: bool = False, + return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # normalize similarity with top-k softmax + # similarity: B x N x [HW/P] + # use inplace with care + if top_k is not None: + values, indices = torch.topk(similarity, k=top_k, dim=1) + + x_exp = values.exp_() + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + if inplace: + similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW + affinity = similarity + else: + affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW + else: + maxes = torch.max(similarity, dim=1, keepdim=True)[0] + x_exp = torch.exp(similarity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + indices = None + + if return_usage: + return affinity, affinity.sum(dim=2) + + return affinity + + +def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, + qe: torch.Tensor, uncert_mask = None) -> torch.Tensor: + # shorthand used in training with no top-k + similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask) + affinity = do_softmax(similarity) + return affinity + +def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor: + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T * H * W) + mem = torch.bmm(mo, affinity) + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1) + mem = mem * uncert_mask + mem = mem.view(B, CV, H, W) + + return mem diff --git a/preprocessing/matanyone/matanyone/model/utils/parameter_groups.py b/preprocessing/matanyone/matanyone/model/utils/parameter_groups.py new file mode 100644 index 0000000..177866a --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/utils/parameter_groups.py @@ -0,0 +1,72 @@ +import logging + +log = logging.getLogger() + + +def get_parameter_groups(model, stage_cfg, print_log=False): + """ + Assign different weight decays and learning rates to different parameters. + Returns a parameter group which can be passed to the optimizer. + """ + weight_decay = stage_cfg.weight_decay + embed_weight_decay = stage_cfg.embed_weight_decay + backbone_lr_ratio = stage_cfg.backbone_lr_ratio + base_lr = stage_cfg.learning_rate + + backbone_params = [] + embed_params = [] + other_params = [] + + embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] + embedding_names = [e + '.weight' for e in embedding_names] + + # inspired by detectron2 + memo = set() + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # Avoid duplicating parameters + if param in memo: + continue + memo.add(param) + + if name.startswith('module'): + name = name[7:] + + inserted = False + if name.startswith('pixel_encoder.'): + backbone_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as a backbone parameter.') + else: + for e in embedding_names: + if name.endswith(e): + embed_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as an embedding parameter.') + break + + if not inserted: + other_params.append(param) + + parameter_groups = [ + { + 'params': backbone_params, + 'lr': base_lr * backbone_lr_ratio, + 'weight_decay': weight_decay + }, + { + 'params': embed_params, + 'lr': base_lr, + 'weight_decay': embed_weight_decay + }, + { + 'params': other_params, + 'lr': base_lr, + 'weight_decay': weight_decay + }, + ] + + return parameter_groups \ No newline at end of file diff --git a/preprocessing/matanyone/matanyone/model/utils/resnet.py b/preprocessing/matanyone/matanyone/model/utils/resnet.py new file mode 100644 index 0000000..44886ee --- /dev/null +++ b/preprocessing/matanyone/matanyone/model/utils/resnet.py @@ -0,0 +1,179 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if 'num_batches_tracked' not in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model diff --git a/preprocessing/matanyone/matanyone_wrapper.py b/preprocessing/matanyone/matanyone_wrapper.py new file mode 100644 index 0000000..82fb773 --- /dev/null +++ b/preprocessing/matanyone/matanyone_wrapper.py @@ -0,0 +1,73 @@ +import tqdm +import torch +from torchvision.transforms.functional import to_tensor +import numpy as np +import random +import cv2 + +def gen_dilate(alpha, min_kernel_size, max_kernel_size): + kernel_size = random.randint(min_kernel_size, max_kernel_size) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) + fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) + dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255 + return dilate.astype(np.float32) + +def gen_erosion(alpha, min_kernel_size, max_kernel_size): + kernel_size = random.randint(min_kernel_size, max_kernel_size) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) + fg = np.array(np.equal(alpha, 255).astype(np.float32)) + erode = cv2.erode(fg, kernel, iterations=1)*255 + return erode.astype(np.float32) + +@torch.inference_mode() +@torch.amp.autocast('cuda') +def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10): + """ + Args: + frames_np: [(H,W,C)]*n, uint8 + mask: (H,W), uint8 + Outputs: + com: [(H,W,C)]*n, uint8 + pha: [(H,W,C)]*n, uint8 + """ + + # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====') + bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3)) + objects = [1] + + # [optional] erode & dilate on given seg mask + if r_dilate > 0: + mask = gen_dilate(mask, r_dilate, r_dilate) + if r_erode > 0: + mask = gen_erosion(mask, r_erode, r_erode) + + mask = torch.from_numpy(mask).cuda() + + frames_np = [frames_np[0]]* n_warmup + frames_np + + frames = [] + phas = [] + for ti, frame_single in tqdm.tqdm(enumerate(frames_np)): + image = to_tensor(frame_single).cuda().float() + + if ti == 0: + output_prob = processor.step(image, mask, objects=objects) # encode given mask + output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames + else: + if ti <= n_warmup: + output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames + else: + output_prob = processor.step(image) + + # convert output probabilities to an object mask + mask = processor.output_prob_to_mask(output_prob) + + pha = mask.unsqueeze(2).cpu().numpy() + com_np = frame_single / 255. * pha + bgr * (1 - pha) + + # DONOT save the warmup frames + if ti > (n_warmup-1): + frames.append((com_np*255).astype(np.uint8)) + phas.append((pha*255).astype(np.uint8)) + + return frames, phas \ No newline at end of file diff --git a/preprocessing/matanyone/tools/__init__.py b/preprocessing/matanyone/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/tools/base_segmenter.py b/preprocessing/matanyone/tools/base_segmenter.py new file mode 100644 index 0000000..096038e --- /dev/null +++ b/preprocessing/matanyone/tools/base_segmenter.py @@ -0,0 +1,141 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter + + +class BaseSegmenter: + def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): + """ + device: model device + SAM_checkpoint: path of SAM checkpoint + model_type: vit_b, vit_l, vit_h + """ + print(f"Initializing BaseSegmenter to {device}") + assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' + + self.device = device + # SAM_checkpoint = None + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + from accelerate import init_empty_weights + + # self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) + with init_empty_weights(): + self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) + from mmgp import offload + # self.model.to(torch.float16) + # offload.save_model(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors") + + offload.load_model_data(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors") + self.model.to(torch.float32) # need to be optimized, if not f32 crappy precision + self.model.to(device=self.device) + self.predictor = SamPredictor(self.model) + self.embedded = False + + @torch.no_grad() + def set_image(self, image: np.ndarray): + # PIL.open(image_path) 3channel: RGB + # image embedding: avoid encode the same image multiple times + self.orignal_image = image + if self.embedded: + print('repeat embedding, please reset_image.') + return + self.predictor.set_image(image) + self.embedded = True + return + + @torch.no_grad() + def reset_image(self): + # reset image embeding + self.predictor.reset_image() + self.embedded = False + + def predict(self, prompts, mode, multimask=True): + """ + image: numpy array, h, w, 3 + prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' + prompts['point_coords']: numpy array [N,2] + prompts['point_labels']: numpy array [1,N] + prompts['mask_input']: numpy array [1,256,256] + mode: 'point' (points only), 'mask' (mask only), 'both' (consider both) + mask_outputs: True (return 3 masks), False (return 1 mask only) + whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :] + """ + assert self.embedded, 'prediction is called before set_image (feature embedding).' + assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both' + + with torch.autocast(device_type='cuda', dtype=torch.float16): + if mode == 'point': + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + multimask_output=multimask) + elif mode == 'mask': + masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'], + multimask_output=multimask) + elif mode == 'both': # both + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + mask_input=prompts['mask_input'], + multimask_output=multimask) + else: + raise("Not implement now!") + # masks (n, h, w), scores (n,), logits (n, 256, 256) + return masks, scores, logits + + +if __name__ == "__main__": + # load and show an image + image = cv2.imread('/hhd3/gaoshang/truck.jpg') + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3) + + # initialise BaseSegmenter + SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + model_type = 'vit_h' + device = "cuda:4" + base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) + + # image embedding (once embedded, multiple prompts can be applied) + base_segmenter.set_image(image) + + # examples + # point only ------------------------ + mode = 'point' + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 1]), + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) + + # both ------------------------ + mode = 'both' + mask_input = logits[np.argmax(scores), :, :] + prompts = {'mask_input': mask_input [None, :, :]} + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 0]), + 'mask_input': mask_input[None, :, :] + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image) + + # mask only ------------------------ + mode = 'mask' + mask_input = logits[np.argmax(scores), :, :] + + prompts = {'mask_input': mask_input[None, :, :]} + + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image) diff --git a/preprocessing/matanyone/tools/download_util.py b/preprocessing/matanyone/tools/download_util.py new file mode 100644 index 0000000..5e8fb1b --- /dev/null +++ b/preprocessing/matanyone/tools/download_util.py @@ -0,0 +1,109 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/preprocessing/matanyone/tools/interact_tools.py b/preprocessing/matanyone/tools/interact_tools.py new file mode 100644 index 0000000..c70b8c4 --- /dev/null +++ b/preprocessing/matanyone/tools/interact_tools.py @@ -0,0 +1,99 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter as mask_painter2 +from .base_segmenter import BaseSegmenter +from .painter import mask_painter, point_painter +import os +import requests +import sys + + +mask_color = 3 +mask_alpha = 0.7 +contour_color = 1 +contour_width = 5 +point_color_ne = 8 +point_color_ps = 50 +point_alpha = 0.9 +point_radius = 15 +contour_color = 2 +contour_width = 5 + + +class SamControler(): + def __init__(self, SAM_checkpoint, model_type, device): + ''' + initialize sam controler + ''' + self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) + + + # def seg_again(self, image: np.ndarray): + # ''' + # it is used when interact in video + # ''' + # self.sam_controler.reset_image() + # self.sam_controler.set_image(image) + # return + + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): + ''' + it is used in first frame in video + return: mask, logit, painted image(mask+point) + ''' + # self.sam_controler.set_image(image) + origal_image = self.sam_controler.orignal_image + neg_flag = labels[-1] + if neg_flag==1: + #find neg + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + prompts = { + 'point_coords': points, + 'point_labels': labels, + 'mask_input': logit[None, :, :] + } + masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + else: + #find positive + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + + assert len(points)==len(labels) + + painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + painted_image = Image.fromarray(painted_image) + + return mask, logit, painted_image + + + + + + + + + + + + \ No newline at end of file diff --git a/preprocessing/matanyone/tools/mask_painter.py b/preprocessing/matanyone/tools/mask_painter.py new file mode 100644 index 0000000..f471ea0 --- /dev/null +++ b/preprocessing/matanyone/tools/mask_painter.py @@ -0,0 +1,288 @@ +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha): + background_color = np.array(background_color) + contour_color = np.array(contour_color) + + # background_mask = 1 - background_mask + # contour_mask = 1 - contour_mask + + for i in range(3): + image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \ + + background_color[i] * (background_alpha-background_mask*background_alpha) + + image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \ + + contour_color[i] * (contour_alpha-contour_mask*contour_alpha) + + return image.astype('uint8') + + +def mask_generator_00(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + return mask, contour_mask + + +def mask_generator_01(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return mask, contour_mask + + +def mask_generator_10(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + return background_mask, contour_mask + + +def mask_generator_11(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return background_mask, contour_mask + + +def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'): + """ + Input: + input_image: numpy array + input_mask: numpy array + background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing + background_blur_radius: radius of background blur, must be odd number + contour_width: width of mask contour, must be odd number + contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others + contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted + mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both + + Output: + painted_image: numpy array + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape' + assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD' + assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11' + + # downsample input image and mask + width, height = input_image.shape[0], input_image.shape[1] + res = 1024 + ratio = min(1.0 * res / max(width, height), 1.0) + input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio))) + input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio))) + + # 0: background, 1: foreground + msk = np.clip(input_mask, 0, 1) + + # generate masks for background and contour pixels + background_radius = (background_blur_radius - 1) // 2 + contour_radius = (contour_width - 1) // 2 + generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11} + background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius) + + # paint + painted_image = vis_add_mask\ + (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background + + return painted_image + + +if __name__ == '__main__': + + background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing + background_blur_radius = 31 # radius of background blur, must be odd number + contour_width = 11 # contour width, must be odd number + contour_color = 3 # id in color map, 0: black, 1: white, >1: others + contour_alpha = 1 # transparency of background, 0: no contour highlighted + + # load input image and mask + input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P')) + + # paint + overall_time_1 = 0 + overall_time_2 = 0 + overall_time_3 = 0 + overall_time_4 = 0 + overall_time_5 = 0 + + for i in range(50): + t2 = time.time() + painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00') + e2 = time.time() + + t3 = time.time() + painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10') + e3 = time.time() + + t1 = time.time() + painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha) + e1 = time.time() + + t4 = time.time() + painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01') + e4 = time.time() + + t5 = time.time() + painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11') + e5 = time.time() + + overall_time_1 += (e1 - t1) + overall_time_2 += (e2 - t2) + overall_time_3 += (e3 - t3) + overall_time_4 += (e4 - t4) + overall_time_5 += (e5 - t5) + + print(f'average time w gaussian: {overall_time_1/50}') + print(f'average time w/o gaussian00: {overall_time_2/50}') + print(f'average time w/o gaussian10: {overall_time_3/50}') + print(f'average time w/o gaussian01: {overall_time_4/50}') + print(f'average time w/o gaussian11: {overall_time_5/50}') + + # save + painted_image_00 = Image.fromarray(painted_image_00) + painted_image_00.save('./test_img/painter_output_image_00.png') + + painted_image_10 = Image.fromarray(painted_image_10) + painted_image_10.save('./test_img/painter_output_image_10.png') + + painted_image_01 = Image.fromarray(painted_image_01) + painted_image_01.save('./test_img/painter_output_image_01.png') + + painted_image_11 = Image.fromarray(painted_image_11) + painted_image_11.save('./test_img/painter_output_image_11.png') diff --git a/preprocessing/matanyone/tools/misc.py b/preprocessing/matanyone/tools/misc.py new file mode 100644 index 0000000..43b8499 --- /dev/null +++ b/preprocessing/matanyone/tools/misc.py @@ -0,0 +1,131 @@ +import os +import re +import random +import time +import torch +import torch.nn as nn +import logging +import numpy as np +from os import path as osp + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +initialized_logger = {} +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + + if log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) \ No newline at end of file diff --git a/preprocessing/matanyone/tools/painter.py b/preprocessing/matanyone/tools/painter.py new file mode 100644 index 0000000..0e711d3 --- /dev/null +++ b/preprocessing/matanyone/tools/painter.py @@ -0,0 +1,215 @@ +# paint masks, contours, or points on images, with specified colors +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, mask, color, alpha): + color = np.array(color_list[color]) + mask = mask > 0.5 + image[mask] = image[mask] * (1-alpha) + color * alpha + return image.astype('uint8') + +def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5): + h, w = input_image.shape[:2] + point_mask = np.zeros((h, w)).astype('uint8') + for point in input_points: + point_mask[point[1], point[0]] = 1 + + kernel = cv2.getStructuringElement(2, (point_radius, point_radius)) + point_mask = cv2.dilate(point_mask, kernel) + + contour_radius = (contour_width - 1) // 2 + dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + return painted_image + +def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.clip(input_mask, 0, 1) + contour_radius = (contour_width - 1) // 2 + + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + + return painted_image + +def background_remover(input_image, input_mask): + """ + input_image: H, W, 3, np.array + input_mask: H, W, np.array + + image_wo_background: PIL.Image + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255 + image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4 + image_wo_background = Image.fromarray(image_wo_background).convert('RGBA') + + return image_wo_background + +if __name__ == '__main__': + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P')) + + # example of mask painter + mask_color = 3 + mask_alpha = 0.7 + contour_color = 1 + contour_width = 5 + + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original.png') + + painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width) + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original1.png') + + # example of point painter + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_points = np.array([[500, 375], [70, 600]]) # x, y + point_color = 5 + point_alpha = 0.9 + point_radius = 15 + contour_color = 2 + contour_width = 5 + painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width) + # save + painted_image = Image.fromarray(painted_image_1) + painted_image.save('images/point_painter_1.png') + + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29) + # save + painted_image = Image.fromarray(painted_image_2) + painted_image.save('images/point_painter_2.png') + + # example of background remover + input_image = np.array(Image.open('images/original.png').convert('RGB')) + image_wo_background = background_remover(input_image, input_mask) # return PIL.Image + image_wo_background.save('images/image_wo_background.png') diff --git a/preprocessing/matanyone/utils/__init__.py b/preprocessing/matanyone/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/matanyone/utils/get_default_model.py b/preprocessing/matanyone/utils/get_default_model.py new file mode 100644 index 0000000..c51eae6 --- /dev/null +++ b/preprocessing/matanyone/utils/get_default_model.py @@ -0,0 +1,27 @@ +""" +A helper function to get a default model for quick testing +""" +from omegaconf import open_dict +from hydra import compose, initialize + +import torch +from ..matanyone.model.matanyone import MatAnyone + +def get_matanyone_model(ckpt_path, device=None) -> MatAnyone: + initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config") + cfg = compose(config_name="eval_matanyone_config") + + with open_dict(cfg): + cfg['weights'] = ckpt_path + + # Load the network weights + if device is not None: + matanyone = MatAnyone(cfg, single_object=True).to(device).eval() + model_weights = torch.load(cfg.weights, map_location=device) + else: # if device is not specified, `.cuda()` by default + matanyone = MatAnyone(cfg, single_object=True).cuda().eval() + model_weights = torch.load(cfg.weights) + + matanyone.load_weights(model_weights) + + return matanyone diff --git a/preprocessing/matanyone/utils/tensor_utils.py b/preprocessing/matanyone/utils/tensor_utils.py new file mode 100644 index 0000000..bb25a45 --- /dev/null +++ b/preprocessing/matanyone/utils/tensor_utils.py @@ -0,0 +1,62 @@ +from typing import List, Iterable +import torch +import torch.nn.functional as F + + +# STM +def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): + h, w = in_img.shape[-2:] + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) + lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + + +def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: + if len(img.shape) == 4: + if pad[2] + pad[3] > 0: + img = img[:, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, pad[0]:-pad[1]] + elif len(img.shape) == 3: + if pad[2] + pad[3] > 0: + img = img[:, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, pad[0]:-pad[1]] + elif len(img.shape) == 5: + if pad[2] + pad[3] > 0: + img = img[:, :, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, :, pad[0]:-pad[1]] + else: + raise NotImplementedError + return img + + +# @torch.jit.script +def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor: + with torch.amp.autocast("cuda"): + prob = prob.float() + new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob], + dim).clamp(1e-7, 1 - 1e-7) + logits = torch.log((new_prob / (1 - new_prob))) # (0, 1) --> (-inf, inf) + + return logits + + +# @torch.jit.script +def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor: + # cls_gt: B*1*H*W + B, _, H, W = cls_gt.shape + one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1) + return one_hot \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 70126cd..0b90776 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,8 @@ onnxruntime-gpu rembg[gpu]==2.0.65 matplotlib timm +segment-anything +ffmpeg-python +omegaconf +hydra-core # rembg==2.0.65 \ No newline at end of file diff --git a/wan/modules/model.py b/wan/modules/model.py index e7a76a9..5af4ae8 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -482,7 +482,6 @@ class WanAttentionBlock(nn.Module): y *= 1 + e[4] y += e[3] - ffn = self.ffn[0] gelu = self.ffn[1] ffn2= self.ffn[2] @@ -500,8 +499,6 @@ class WanAttentionBlock(nn.Module): x.addcmul_(y, e[5]) - - if hint is not None: if context_scale == 1: x.add_(hint) @@ -539,24 +536,13 @@ class VaceWanAttentionBlock(WanAttentionBlock): c = hints[0] hints[0] = None if self.block_id == 0: - c = self.before_proj(c) + x + c = self.before_proj(c) + c += x c = super().forward(c, **kwargs) c_skip = self.after_proj(c) hints[0] = c return c_skip - # def forward(self, c, x, **kwargs): - # # behold dbm magic ! - # if self.block_id == 0: - # c = self.before_proj(c) + x - # all_c = [] - # else: - # all_c = c - # c = all_c.pop(-1) - # c = super().forward(c, **kwargs) - # c_skip = self.after_proj(c) - # all_c += [c_skip, c] - # return all_c class Head(nn.Module): @@ -793,37 +779,6 @@ class WanModel(ModelMixin, ConfigMixin): print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") return best_threshold - - - # def forward_vace( - # self, - # x, - # vace_context, - # seq_len, - # context, - # e, - # kwargs - # ): - # # embeddings - # c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] - # c = [u.flatten(2).transpose(1, 2) for u in c] - # if (len(c) == 1 and seq_len == c[0].size(1)): - # c = c[0] - # else: - # c = torch.cat([ - # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - # dim=1) for u in c - # ]) - - # # arguments - # new_kwargs = dict(x=x) - # new_kwargs.update(kwargs) - - # for block in self.vace_blocks: - # c = block(c, context= context, e= e, **new_kwargs) - # hints = c[:-1] - - # return hints def forward( self, diff --git a/wan/text2video.py b/wan/text2video.py index ba24e6d..035700a 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -209,34 +209,52 @@ class WanT2V: def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0): + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): image_sizes = [] - for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + trim_video = len(keep_frames) + + for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): + prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] + num_frames = total_frames - prepend_count if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video) + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) src_video[i] = src_video[i].to(device) src_mask[i] = src_mask[i].to(device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) src_video_shape = src_video[i].shape - if src_video_shape[1] != num_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) image_sizes.append(src_video[i].shape[2:]) elif sub_src_video is None: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) + else: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) image_sizes.append(image_size) else: - src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video) + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) src_video[i] = src_video[i].to(device) src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) + if prepend_count > 0: + src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) + src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) src_video_shape = src_video[i].shape - if src_video_shape[1] != num_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + if src_video_shape[1] != total_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) image_sizes.append(src_video[i].shape[2:]) + for k, keep in enumerate(keep_frames): + if not keep: + src_video[i][:, k:k+1] = 0 + src_mask[i][:, k:k+1] = 1 for i, ref_images in enumerate(src_ref_images): if ref_images is not None: diff --git a/wan/utils/utils.py b/wan/utils/utils.py index ce4ecd5..5149464 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -22,27 +22,27 @@ __all__ = ['cache_video', 'cache_image', 'str2bool'] from PIL import Image -def resample(video_fps, video_frames_count, max_frames, target_fps): +def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math video_frame_duration = 1 /video_fps target_frame_duration = 1 / target_fps - cur_time = 0 - target_time = 0 - frame_no = 0 + target_time = start_target_frame * target_frame_duration + frame_no = math.ceil(target_time / video_frame_duration) + cur_time = frame_no * video_frame_duration frame_ids =[] while True: - if max_frames != 0 and len(frame_ids) >= max_frames: + if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count : break add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration ) frame_no += add_frames_count + if frame_no >= video_frames_count: + break frame_ids.append(frame_no) cur_time += add_frames_count * video_frame_duration target_time += target_frame_duration - if frame_no >= video_frames_count -1: - break - frame_ids = frame_ids[:video_frames_count] + frame_ids = frame_ids[:max_target_frames_count] return frame_ids def get_video_frame(file_name, frame_no): diff --git a/wan/utils/vace_preprocessor.py b/wan/utils/vace_preprocessor.py index 3bfe885..c591cca 100644 --- a/wan/utils/vace_preprocessor.py +++ b/wan/utils/vace_preprocessor.py @@ -182,14 +182,14 @@ class VaceVideoProcessor(object): - def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0): + def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0): from wan.utils.utils import resample target_fps = self.max_fps # video_frames_count = len(frame_timestamps) - frame_ids= resample(fps, video_frames_count, max_frames, target_fps) + frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame ) x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box h, w = y2 - y1, x2 - x1 @@ -206,7 +206,7 @@ class VaceVideoProcessor(object): np.log2(np.sqrt(max_area_z)) ))) - seq_len = max_area_z * ((max_frames- 1) // df +1) + seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1) # of = min( # (len(frame_ids) - 1) // df + 1, @@ -226,9 +226,9 @@ class VaceVideoProcessor(object): return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps - def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0): + def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0): if self.keep_last: - return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames) + return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame) else: return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames) @@ -238,7 +238,7 @@ class VaceVideoProcessor(object): def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) - def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, **kwargs): + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs): rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) # read video import decord @@ -254,7 +254,7 @@ class VaceVideoProcessor(object): if src_video != None: fps = 16 - length = src_video.shape[1] + length = src_video.shape[0] + start_frame if len(readers) > 0: min_readers = min([len(r) for r in readers]) length = min(length, min_readers ) @@ -269,7 +269,7 @@ class VaceVideoProcessor(object): h, w = src_video.shape[1:3] else: h, w = readers[0].next().shape[:2] - frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames) + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame ) # preprocess video videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] diff --git a/wgp.py b/wgp.py index 79dc592..c30c5ec 100644 --- a/wgp.py +++ b/wgp.py @@ -144,16 +144,26 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") return - + sliding_window_repeat = inputs["sliding_window_repeat"] + sliding_window = sliding_window_repeat > 0 if "Vace" in model_filename: video_prompt_type = inputs["video_prompt_type"] image_refs = inputs["image_refs"] video_guide = inputs["video_guide"] video_mask = inputs["video_mask"] - if "Vace" in model_filename and "1.3B" in model_filename : + + if sliding_window: + if inputs["repeat_generation"]!=1: + gr.Info("Only one Video generated per Prompt is supported when Sliding windows is used") + return + if inputs["sliding_window_overlap"]>=inputs["video_length"] : + gr.Info("The number of frames of the Sliding Window Overlap must be less than the Number of Frames to Generate") + return + + if "1.3B" in model_filename : resolution_reformated = str(height) + "*" + str(width) if not resolution_reformated in VACE_SIZE_CONFIGS: - res = VACE_SIZE_CONFIGS.keys().join(" and ") + res = (" and ").join(VACE_SIZE_CONFIGS.keys()) gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") return if "I" in video_prompt_type: @@ -175,12 +185,19 @@ def process_prompt_and_add_tasks(state, model_choice): else: video_mask = None if "O" in video_prompt_type : - max_frames= inputs["max_frames"] + keep_frames= inputs["keep_frames"] video_length = inputs["video_length"] - if max_frames ==0: + if len(keep_frames) ==0: gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.") - elif max_frames >= video_length: - gr.Info(f"The number of frames in the control Video to reuse ({max_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.") + # elif keep_frames >= video_length: + # gr.Info(f"The number of frames in the control Video to reuse ({keep_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.") + # return + elif "V" in video_prompt_type: + keep_frames= inputs["keep_frames"] + video_length = inputs["video_length"] + _, error = parse_keep_frames(keep_frames, video_length) + if len(error) > 0: + gr.Info(f"Invalid Keep Frames property: {error}") return if isinstance(image_refs, list): @@ -190,6 +207,9 @@ def process_prompt_and_add_tasks(state, model_choice): image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1) + if sliding_window and len(prompts) > 0: + prompts = ["\n".join(prompts)] + for single_prompt in prompts: extra_inputs = { "prompt" : single_prompt, @@ -1545,8 +1565,8 @@ def download_models(transformer_filename, text_encoder_filename): from huggingface_hub import hf_hub_download, snapshot_download repoId = "DeepBeepMeep/Wan2.1" - sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ] - fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ] + sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ] + fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ] targetRoot = "ckpts/" for sourceFolder, files in zip(sourceFolderList,fileList ): if len(files)==0: @@ -1787,25 +1807,6 @@ def get_model_name(model_filename): return model_name -# def generate_header(model_filename, compile, attention_mode): - -# header = "