gr.info based quit countdown timer

Chris Malone 2025-04-14 04:07:59 +10:00
commit 93e876f06a
50 changed files with 5856 additions and 235 deletions

@@ -14,10 +14,14 @@
## 🔥 Latest News!!
* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
- A new queuing system that lets you queue as many text2video and image2video tasks as you want. Each task can use completely different generation parameters (different number of frames, steps, loras, ...).
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
- Integrated *Matanyone* tool directly inside WanGP so that you can easily create inpainting masks
- Sliding Window generation for Vace: create windows that can last dozens of seconds
- A new UI: tabs were replaced by a dropdown box to easily switch models
* Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP probably has better end-image support, but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2 1.3B model: Image 2 Video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. You can choose any of these models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
* Mar 26 2025: 👋 Good news! Official support for RTX 50xx: please check the installation instructions below.
* Mar 24 2025: 👋 Wan2.1GP v3.2:

@@ -0,0 +1,683 @@
import sys
import os
import json
import time
import psutil
import ffmpeg
import imageio
from PIL import Image
import cv2
import torch
import numpy as np
import gradio as gr
from .tools.painter import mask_painter
from .tools.interact_tools import SamControler
from .tools.misc import get_device
from .tools.download_util import load_file_from_url
from .utils.get_default_model import get_matanyone_model
from .matanyone.inference.inference_core import InferenceCore
from .matanyone_wrapper import matanyone
arg_device = "cuda"
arg_sam_model_type="vit_h"
arg_mask_save = False
model = None
matanyone_model = None
# SAM generator
class MaskGenerator():
def __init__(self, sam_checkpoint, device):
global arg_device
arg_device = device
self.samcontroler = SamControler(sam_checkpoint, arg_sam_model_type, arg_device)
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
# convert points input to prompt state
def get_prompt(click_state, click_input):
inputs = json.loads(click_input)
points = click_state[0]
labels = click_state[1]
for point in inputs:
points.append(point[:2])
labels.append(point[2])
click_state[0] = points
click_state[1] = labels
prompt = {
"prompt_type":["click"],
"input_point":click_state[0],
"input_label":click_state[1],
"multimask_output":"True",
}
return prompt
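# Example (hypothetical click): with click_state = [[], []] and
# click_input = "[[320, 240, 1]]" (the "[[x, y, label]]" string built by sam_refine below),
# the returned prompt is:
# {"prompt_type": ["click"], "input_point": [[320, 240]], "input_label": [1], "multimask_output": "True"}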
def get_frames_from_image(image_input, image_state):
"""
Args:
video_path:str
timestamp:float64
Return
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
user_name = time.time()
frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
image_size = (frames[0].shape[0],frames[0].shape[1])
# initialize video_state
image_state = {
"user_name": user_name,
"image_name": "output.png",
"origin_images": frames,
"painted_images": frames.copy(),
"masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
"logits": [None]*len(frames),
"select_frame_number": 0,
"last_frame_numer": 0,
"fps": None
}
image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
return image_state, image_info, image_state["origin_images"][0], \
gr.update(visible=True, maximum=10, value=10), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=True), \
gr.update(visible=True)
# extract frames from upload video
def get_frames_from_video(video_input, video_state):
"""
Args:
video_path:str
timestamp:float64
Return
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
while model is None:
time.sleep(1)
video_path = video_input
frames = []
user_name = time.time()
# extract Audio
# try:
# audio_path = video_input.replace(".mp4", "_audio.wav")
# ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True)
# except Exception as e:
# print(f"Audio extraction error: {str(e)}")
# audio_path = "" # Set to "" if extraction fails
# print(f'audio_path: {audio_path}')
audio_path = ""
# extract frames
try:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
while cap.isOpened():
ret, frame = cap.read()
if ret:
current_memory_usage = psutil.virtual_memory().percent
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
if current_memory_usage > 90:
break
else:
break
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
image_size = (frames[0].shape[0],frames[0].shape[1])
# resize if resolution too big
if image_size[0] >= 1280 and image_size[1] >= 1280:
scale = 1080 / min(image_size)
new_w = int(image_size[1] * scale)
new_h = int(image_size[0] * scale)
# update frames
frames = [cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA) for f in frames]
# update image_size
image_size = (frames[0].shape[0],frames[0].shape[1])
# initialize video_state
video_state = {
"user_name": user_name,
"video_name": os.path.split(video_path)[-1],
"origin_images": frames,
"painted_images": frames.copy(),
"masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
"logits": [None]*len(frames),
"select_frame_number": 0,
"last_frame_number": 0,
"fps": fps,
"audio": audio_path
}
video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], \
gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=True), \
gr.update(visible=True)
# get the select frame from gradio slider
def select_video_template(image_selection_slider, video_state, interactive_state):
image_selection_slider -= 1
video_state["select_frame_number"] = image_selection_slider
# once select a new template frame, set the image in sam
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
def select_image_template(image_selection_slider, video_state, interactive_state):
image_selection_slider = 0 # fixed for image
video_state["select_frame_number"] = image_selection_slider
# once select a new template frame, set the image in sam
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
# set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider
return video_state["painted_images"][track_pause_number_slider],interactive_state
# use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData ): #
"""
Args:
template_frame: PIL.Image
point_prompt: flag for positive or negative button click
click_state: [[points], [labels]]
"""
if point_prompt == "Positive":
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
interactive_state["positive_click_times"] += 1
else:
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
interactive_state["negative_click_times"] += 1
# prompt for sam model
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
prompt = get_prompt(click_state=click_state, click_input=coordinate)
mask, logit, painted_image = model.first_frame_click(
image=video_state["origin_images"][video_state["select_frame_number"]],
points=np.array(prompt["input_point"]),
labels=np.array(prompt["input_label"]),
multimask=prompt["multimask_output"],
)
video_state["masks"][video_state["select_frame_number"]] = mask
video_state["logits"][video_state["select_frame_number"]] = logit
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
return painted_image, video_state, interactive_state
def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
def clear_click(video_state, click_state):
click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
return template_frame, click_state
def remove_multi_mask(interactive_state, mask_dropdown):
interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = []
return interactive_state, gr.update(choices=[],value=[])
def show_mask(video_state, interactive_state, mask_dropdown):
mask_dropdown.sort()
if video_state["origin_images"]:
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
for i in range(len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
return select_frame
def save_video(frames, output_path, fps):
writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
return output_path
# video matting
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
# if interactive_state["track_end_number"]:
# following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
# else:
end_slider = max(video_state["select_frame_number"] +1, end_slider)
following_frames = video_state["origin_images"][video_state["select_frame_number"]: end_slider]
if interactive_state["multi_mask"]["masks"]:
if len(mask_dropdown) == 0:
mask_dropdown = ["mask_001"]
mask_dropdown.sort()
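# Combine the selected masks into a single label map: pixels of "mask_001" get label 1,
# "mask_002" gets label 2, and so on; where masks overlap, the clip keeps the later (higher) label.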
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
for i in range(1,len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
video_state["masks"][video_state["select_frame_number"]]= template_mask
else:
template_mask = video_state["masks"][video_state["select_frame_number"]]
fps = video_state["fps"]
audio_path = video_state["audio"]
# operation error
if len(np.unique(template_mask))==1:
template_mask[0][0]=1
foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
output_frames = []
foreground_mat = matting_type == "Foreground"
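# Binarize the predicted alpha at 127 and grey out the matted region: for "Foreground"
# matting the foreground pixels are replaced with mid-grey (127), for "Background" the
# background pixels are, while the remaining pixels keep the original frame. This produces
# the "Masked Video Output" that can be exported as a Vace inpainting input.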
for frame_origin, frame_alpha in zip(following_frames, alpha):
if foreground_mat:
frame_alpha[frame_alpha > 127] = 255
frame_alpha[frame_alpha <= 127] = 0
else:
frame_temp = frame_alpha.copy()
frame_alpha[frame_temp > 127] = 0
frame_alpha[frame_temp <= 127] = 255
output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
frame_grey = frame_alpha.copy()
frame_grey[frame_alpha == 255] = 127
output_frame += frame_grey
output_frames.append(output_frame)
foreground = output_frames
if not os.path.exists("mask_outputs"):
os.makedirs("mask_outputs")
file_name= video_state["video_name"]
file_name = ".".join(file_name.split(".")[:-1])
foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps)
# foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps)
# alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
def show_outputs():
return gr.update(visible=True), gr.update(visible=True)
def add_audio_to_video(video_path, audio_path, output_path):
try:
video_input = ffmpeg.input(video_path)
audio_input = ffmpeg.input(audio_path)
_ = (
ffmpeg
.output(video_input, audio_input, output_path, vcodec="copy", acodec="aac")
.run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
)
return output_path
except ffmpeg.Error as e:
print(f"FFmpeg error:\n{e.stderr.decode()}")
return None
def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""):
"""
Generates a video from a list of frames.
Args:
frames (list of numpy arrays): The frames to include in the video.
output_path (str): The path to save the generated video.
fps (int, optional): The frame rate of the output video. Defaults to 30.
"""
frames = torch.from_numpy(np.asarray(frames))
_, h, w, _ = frames.shape
if gray2rgb:
frames = np.repeat(frames, 3, axis=3)
if not os.path.exists(os.path.dirname(output_path)):
os.makedirs(os.path.dirname(output_path))
video_temp_path = output_path.replace(".mp4", "_temp.mp4")
# resize back to ensure input resolution
imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7,
codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"])
# add audio to video if audio path exists
if audio_path != "" and os.path.exists(audio_path):
output_path = add_audio_to_video(video_temp_path, audio_path, output_path)
os.remove(video_temp_path)
return output_path
else:
return video_temp_path
# reset all states for a new input
def restart():
return {
"user_name": "",
"video_name": "",
"origin_images": None,
"painted_images": None,
"masks": None,
"inpaint_masks": None,
"logits": None,
"select_frame_number": 0,
"fps": 30
}, {
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": arg_mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
},
"track_end_number": None,
}, [[],[]], None, None, \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
def load_unload_models(selected):
global model
global matanyone_model
if selected:
# args, defined in track_anything.py
sam_checkpoint_url_dict = {
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# os.path.join('.')
from mmgp import offload
# sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
sam_checkpoint = None
transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
# initialize sams
model = MaskGenerator(sam_checkpoint, "cuda")
from .matanyone.model.matanyone import MatAnyone
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
# pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
# offload.profile(pipe)
matanyone_model = matanyone_model.to(arg_device).eval()
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
else:
import gc
model = None
matanyone_model = None
gc.collect()
torch.cuda.empty_cache()
def get_vmc_event_handler():
return load_unload_models
def export_to_vace_video_input(foreground_video_output):
gr.Info("Masked Video Input transferred to Vace For Inpainting")
return "V#" + str(time.time()), foreground_video_output
def export_to_vace_video_mask(foreground_video_output, alpha_video_output):
gr.Info("Masked Video Input and Full Mask transferred to Vace For Stronger Inpainting")
return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
# download assets
gr.Markdown("Mast Edition is provided by MatAnyone")
with gr.Column( visible=True):
with gr.Row():
with gr.Accordion("Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
with gr.Row():
with gr.Column():
gr.Markdown("### Case 1: Single Target")
gr.Video(value="preprocessing/matanyone/tutorial_single_target.mp4", elem_classes="video")
with gr.Column():
gr.Markdown("### Case 2: Multiple Targets")
gr.Video(value="preprocessing/matanyone/tutorial_multi_targets.mp4", elem_classes="video")
click_state = gr.State([[],[]])
interactive_state = gr.State({
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": arg_mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
},
"track_end_number": None,
}
)
video_state = gr.State(
{
"user_name": "",
"video_name": "",
"origin_images": None,
"painted_images": None,
"masks": None,
"inpaint_masks": None,
"logits": None,
"select_frame_number": 0,
"fps": 16,
"audio": "",
}
)
with gr.Column( visible=True):
with gr.Row():
with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
with gr.Row():
erode_kernel_size = gr.Slider(label='Erode Kernel Size',
minimum=0,
maximum=30,
step=1,
value=10,
info="Erosion on the added mask",
interactive=True)
dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
minimum=0,
maximum=30,
step=1,
value=10,
info="Dilation on the added mask",
interactive=True)
with gr.Row():
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False)
end_selection_slider = gr.Slider(minimum=1, maximum=300, step=1, value=81, label="Last Frame to Process", info="Last Frame to Process", visible=False)
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="End frame", visible=False)
with gr.Row():
point_prompt = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
info="Click to add positive or negative point for target mask",
interactive=True,
visible=False,
min_width=100,
scale=1)
matting_type = gr.Radio(
choices=["Foreground", "Background"],
value="Foreground",
label="Matting Type",
info="Type of Video Matting to Generate",
interactive=True,
visible=False,
min_width=100,
scale=1)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose one or more of the masks added in Step 2", visible=False, scale=2)
gr.Markdown("---")
with gr.Column():
# input video
with gr.Row(equal_height=True):
with gr.Column(scale=2):
gr.Markdown("## Step1: Upload video")
with gr.Column(scale=2):
step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
with gr.Row(equal_height=True):
with gr.Column(scale=2):
video_input = gr.Video(label="Input Video", elem_classes="video")
extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button")
with gr.Column(scale=2):
video_info = gr.Textbox(label="Video Info", visible=False)
template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
with gr.Row():
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
with gr.Row():
gr.Markdown("")
# output video
with gr.Row(equal_height=True) as output_row:
with gr.Column(scale=2):
foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input for Inpainting", visible=False)
with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting", visible= False)
export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input])
export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
# first step: get the video information
extract_frames_button.click(
fn=get_frames_from_video,
inputs=[
video_input, video_state
],
outputs=[video_state, video_info, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
)
# second step: select images from slider
image_selection_slider.release(fn=select_video_template,
inputs=[image_selection_slider, video_state, interactive_state],
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
track_pause_number_slider.release(fn=get_end_number,
inputs=[track_pause_number_slider, video_state, interactive_state],
outputs=[template_frame, interactive_state], api_name="end_image")
# click select image to get mask using sam
template_frame.select(
fn=sam_refine,
inputs=[video_state, point_prompt, click_state, interactive_state],
outputs=[template_frame, video_state, interactive_state]
)
# add different mask
add_mask_button.click(
fn=add_multi_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
)
remove_mask_button.click(
fn=remove_multi_mask,
inputs=[interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown]
)
# video matting
matting_button.click(
fn=show_outputs,
inputs=[],
outputs=[foreground_video_output, alpha_video_output]).then(
fn=video_matting,
inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
outputs=[foreground_video_output, alpha_video_output,foreground_video_output, alpha_video_output, export_to_vace_video_input_btn, export_to_vace_video_mask_btn]
)
# click to get mask
mask_dropdown.change(
fn=show_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[template_frame]
)
# clear input
video_input.change(
fn=restart,
inputs=[],
outputs=[
video_state,
interactive_state,
click_state,
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
],
queue=False,
show_progress=False)
video_input.clear(
fn=restart,
inputs=[],
outputs=[
video_state,
interactive_state,
click_state,
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
],
queue=False,
show_progress=False)
# points clear
clear_button_click.click(
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state],
)

@@ -0,0 +1,47 @@
defaults:
- _self_
- model: base
- override hydra/job_logging: custom-no-rank.yaml
hydra:
run:
dir: ../output/${exp_id}/${dataset}
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
amp: False
weights: pretrained_models/matanyone.pth # default (can be modified from outside)
output_dir: null # defaults to run_dir; specify this to override
flip_aug: False
# maximum shortest side of the input; -1 means no resizing
# With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
# this parameter is added solely for the GUI in the current codebase
# InferenceCore will downsize the input and restore the output to the original size if needed
# if you are using this code for some other project, you can also utilize this parameter
max_internal_size: -1
# these parameters, when set, override the dataset's default; useful for debugging
save_all: True
use_all_masks: False
use_long_term: False
mem_every: 5
# only relevant when long_term is not enabled
max_mem_frames: 5
# only relevant when long_term is enabled
long_term:
count_usage: True
max_mem_frames: 10
min_mem_frames: 5
num_prototypes: 128
max_num_tokens: 10000
buffer_tokens: 2000
top_k: 30
stagger_updates: 5
chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
save_scores: False
save_aux: False
visualize: False
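# Note: in this commit InferenceCore (inference_core.py below) reads mem_every,
# stagger_updates, chunk_size, save_aux, max_internal_size and flip_aug from this config,
# while MemoryManager (memory_manager.py below) reads top_k, use_long_term, max_mem_frames
# and the long_term block.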

@@ -0,0 +1,22 @@
# python logging configuration for tasks
version: 1
formatters:
simple:
format: '[%(asctime)s][%(levelname)s] - %(message)s'
datefmt: '%Y-%m-%d %H:%M:%S'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
file:
class: logging.FileHandler
formatter: simple
# absolute file path
filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
mode: w
root:
level: INFO
handlers: [console, file]
disable_existing_loggers: false

@@ -0,0 +1,22 @@
# python logging configuration for tasks
version: 1
formatters:
simple:
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
datefmt: '%Y-%m-%d %H:%M:%S'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
file:
class: logging.FileHandler
formatter: simple
# absolute file path
filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
mode: w
root:
level: INFO
handlers: [console, file]
disable_existing_loggers: false

@@ -0,0 +1,58 @@
pixel_mean: [0.485, 0.456, 0.406]
pixel_std: [0.229, 0.224, 0.225]
pixel_dim: 256
key_dim: 64
value_dim: 256
sensory_dim: 256
embed_dim: 256
pixel_encoder:
type: resnet50
ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
mask_encoder:
type: resnet18
final_dim: 256
pixel_pe_scale: 32
pixel_pe_temperature: 128
object_transformer:
embed_dim: ${model.embed_dim}
ff_dim: 2048
num_heads: 8
num_blocks: 3
num_queries: 16
read_from_pixel:
input_norm: False
input_add_pe: False
add_pe_to_qkv: [True, True, False]
read_from_past:
add_pe_to_qkv: [True, True, False]
read_from_memory:
add_pe_to_qkv: [True, True, False]
read_from_query:
add_pe_to_qkv: [True, True, False]
output_norm: False
query_self_attention:
add_pe_to_qkv: [True, True, False]
pixel_self_attention:
add_pe_to_qkv: [True, True, False]
object_summarizer:
embed_dim: ${model.object_transformer.embed_dim}
num_summaries: ${model.object_transformer.num_queries}
add_pe: True
aux_loss:
sensory:
enabled: True
weight: 0.01
query:
enabled: True
weight: 0.01
mask_decoder:
# first value must equal embed_dim
up_dims: [256, 128, 128, 64, 16]

@@ -0,0 +1,56 @@
import warnings
from typing import Iterable
import torch
from ..model.matanyone import MatAnyone
class ImageFeatureStore:
"""
A cache for image features.
These features might be reused at different parts of the inference pipeline.
This class provides an interface for reusing these features.
It is the user's responsibility to delete redundant features.
The features of a frame should be associated with a unique index -- typically the frame id.
"""
def __init__(self, network: MatAnyone, no_warning: bool = False):
self.network = network
self._store = {}
self.no_warning = no_warning
def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
key, shrinkage, selection = self.network.transform_key(ms_features[0])
self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
seq_length = images.shape[0]
ms_features, pix_feat = self.network.encode_image(images, seq_length)
key, shrinkage, selection = self.network.transform_key(ms_features[0])
for index in range(seq_length):
self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
def get_features(self, index: int,
image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
if index not in self._store:
self._encode_feature(index, image, last_feats)
return self._store[index][:2]
def get_key(self, index: int,
image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
if index not in self._store:
self._encode_feature(index, image, last_feats)
return self._store[index][2:]
def delete(self, index: int) -> None:
if index in self._store:
del self._store[index]
def __len__(self):
return len(self._store)
def __del__(self):
if len(self._store) > 0 and not self.no_warning:
warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
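# A minimal usage sketch (illustrative only; the tensor shape is an assumption, not taken
# from this file). Features are computed lazily on first access and cached per frame index,
# so repeated calls with the same index reuse the stored tensors until delete() is called:
#
# store = ImageFeatureStore(network)                    # network: a loaded MatAnyone model
# ms_feat, pix_feat = store.get_features(0, image)      # image: a 3*H*W tensor (assumed)
# key, shrinkage, selection = store.get_key(0, image)   # reuses the cached encoding
# store.delete(0)                                       # free the cached features of frame 0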

@@ -0,0 +1,406 @@
from typing import List, Optional, Iterable
import logging
from omegaconf import DictConfig
import numpy as np
import torch
import torch.nn.functional as F
from .memory_manager import MemoryManager
from .object_manager import ObjectManager
from .image_feature_store import ImageFeatureStore
from ..model.matanyone import MatAnyone
from ...utils.tensor_utils import pad_divide_by, unpad, aggregate
log = logging.getLogger()
class InferenceCore:
def __init__(self,
network: MatAnyone,
cfg: DictConfig,
*,
image_feature_store: ImageFeatureStore = None):
self.network = network
self.cfg = cfg
self.mem_every = cfg.mem_every
stagger_updates = cfg.stagger_updates
self.chunk_size = cfg.chunk_size
self.save_aux = cfg.save_aux
self.max_internal_size = cfg.max_internal_size
self.flip_aug = cfg.flip_aug
self.curr_ti = -1
self.last_mem_ti = 0
# at which time indices should we update the sensory memory
if stagger_updates >= self.mem_every:
self.stagger_ti = set(range(1, self.mem_every + 1))
else:
self.stagger_ti = set(
np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
self.object_manager = ObjectManager()
self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)
if image_feature_store is None:
self.image_feature_store = ImageFeatureStore(self.network)
else:
self.image_feature_store = image_feature_store
self.last_mask = None
self.last_pix_feat = None
self.last_msk_value = None
def clear_memory(self):
self.curr_ti = -1
self.last_mem_ti = 0
self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)
def clear_non_permanent_memory(self):
self.curr_ti = -1
self.last_mem_ti = 0
self.memory.clear_non_permanent_memory()
def clear_sensory_memory(self):
self.curr_ti = -1
self.last_mem_ti = 0
self.memory.clear_sensory_memory()
def update_config(self, cfg):
self.mem_every = cfg['mem_every']
self.memory.update_config(cfg)
def clear_temp_mem(self):
self.memory.clear_work_mem()
# self.object_manager = ObjectManager()
self.memory.clear_obj_mem()
# self.memory.clear_sensory_memory()
def _add_memory(self,
image: torch.Tensor,
pix_feat: torch.Tensor,
prob: torch.Tensor,
key: torch.Tensor,
shrinkage: torch.Tensor,
selection: torch.Tensor,
*,
is_deep_update: bool = True,
force_permanent: bool = False) -> None:
"""
Memorize the given segmentation in all memory stores.
The batch dimension is 1 if flip augmentation is not used.
image: RGB image, (1/2)*3*H*W
pix_feat: from the key encoder, (1/2)*_*H*W
prob: (1/2)*num_objects*H*W, in [0, 1]
key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
selection can be None if not using long-term memory
is_deep_update: whether to use deep update (e.g. with the mask encoder)
force_permanent: whether to force the memory to be permanent
"""
if prob.shape[1] == 0:
# nothing to add
log.warn('Trying to add an empty object mask to memory!')
return
if force_permanent:
as_permanent = 'all'
else:
as_permanent = 'first'
self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
msk_value, sensory, obj_value, _ = self.network.encode_mask(
image,
pix_feat,
self.memory.get_sensory(self.object_manager.all_obj_ids),
prob,
deep_update=is_deep_update,
chunk_size=self.chunk_size,
need_weights=self.save_aux)
self.memory.add_memory(key,
shrinkage,
msk_value,
obj_value,
self.object_manager.all_obj_ids,
selection=selection,
as_permanent=as_permanent)
self.last_mem_ti = self.curr_ti
if is_deep_update:
self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
self.last_msk_value = msk_value
def _segment(self,
key: torch.Tensor,
selection: torch.Tensor,
pix_feat: torch.Tensor,
ms_features: Iterable[torch.Tensor],
update_sensory: bool = True) -> torch.Tensor:
"""
Produce a segmentation using the given features and the memory
The batch dimension is 1 if flip augmentation is not used.
key/selection: for anisotropic l2: (1/2) * _ * H * W
pix_feat: from the key encoder, (1/2) * _ * H * W
ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
with strides 16, 8, and 4 respectively
update_sensory: whether to update the sensory memory
Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
"""
bs = key.shape[0]
if self.flip_aug:
assert bs == 2
else:
assert bs == 1
if not self.memory.engaged:
log.warn('Trying to segment without any memory!')
return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
device=key.device,
dtype=key.dtype)
uncert_output = None
if self.curr_ti == 0: # ONLY for the first frame for prediction
memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
else:
memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
memory_readout = self.object_manager.realize_dict(memory_readout)
sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
memory_readout,
self.memory.get_sensory(
self.object_manager.all_obj_ids),
chunk_size=self.chunk_size,
update_sensory=update_sensory)
# remove batch dim
if self.flip_aug:
# average predictions of the non-flipped and flipped version
pred_prob_with_bg = (pred_prob_with_bg[0] +
torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
else:
pred_prob_with_bg = pred_prob_with_bg[0]
if update_sensory:
self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
return pred_prob_with_bg
def pred_all_flow(self, images):
self.total_len = images.shape[0]
images, self.pad = pad_divide_by(images, 16)
images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w)
self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)
def encode_all_images(self, images):
images, self.pad = pad_divide_by(images, 16)
self.image_feature_store.get_all_features(images) # t c h w
return images
def step(self,
image: torch.Tensor,
mask: Optional[torch.Tensor] = None,
objects: Optional[List[int]] = None,
*,
idx_mask: bool = False,
end: bool = False,
delete_buffer: bool = True,
force_permanent: bool = False,
matting: bool = True,
first_frame_pred: bool = False) -> torch.Tensor:
"""
Take a step with a new incoming image.
If there is an incoming mask with new objects, we will memorize them.
If there is no incoming mask, we will segment the image using the memory.
In both cases, we will update the memory and return a segmentation.
image: 3*H*W
mask: H*W (if idx mask) or len(objects)*H*W or None
objects: list of object ids that are valid in the mask Tensor.
The ids themselves do not need to be consecutive/in order, but they need to be
in the same position in the list as the corresponding mask
in the tensor in non-idx-mask mode.
objects is ignored if the mask is None.
If idx_mask is False and objects is None, we sequentially infer the object ids.
idx_mask: if True, mask is expected to contain an object id at every pixel.
If False, mask should have multiple channels with each channel representing one object.
end: if we are at the end of the sequence, we do not need to update memory
if unsure just set it to False
delete_buffer: whether to delete the image feature buffer after this step
force_permanent: the memory recorded this frame will be added to the permanent memory
"""
if objects is None and mask is not None:
assert not idx_mask
objects = list(range(1, mask.shape[0] + 1))
# resize input if needed -- currently only used for the GUI
resize_needed = False
if self.max_internal_size > 0:
h, w = image.shape[-2:]
min_side = min(h, w)
if min_side > self.max_internal_size:
resize_needed = True
new_h = int(h / min_side * self.max_internal_size)
new_w = int(w / min_side * self.max_internal_size)
image = F.interpolate(image.unsqueeze(0),
size=(new_h, new_w),
mode='bilinear',
align_corners=False)[0]
if mask is not None:
if idx_mask:
mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
size=(new_h, new_w),
mode='nearest-exact')[0, 0].round().long()
else:
mask = F.interpolate(mask.unsqueeze(0),
size=(new_h, new_w),
mode='bilinear',
align_corners=False)[0]
self.curr_ti += 1
image, self.pad = pad_divide_by(image, 16) # already done for 3DCNN
image = image.unsqueeze(0) # add the batch dimension
if self.flip_aug:
image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)
# whether to update the working memory
is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
(mask is not None)) and (not end)
# segment when there is no input mask or when the input mask is incomplete
need_segment = (mask is None) or (self.object_manager.num_obj > 0
and not self.object_manager.has_all(objects))
update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)
# reinit if it is the first frame for prediction
if first_frame_pred:
self.curr_ti = 0
self.last_mem_ti = 0
is_mem_frame = True
need_segment = True
update_sensory = True
# encoding the image
ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)
# segmentation from memory if needed
if need_segment:
pred_prob_with_bg = self._segment(key,
selection,
pix_feat,
ms_feat,
update_sensory=update_sensory)
# use the input mask if provided
if mask is not None:
# inform the manager of the new objects, and get a list of temporary id
# temporary ids -- indicates the position of objects in the tensor
# (starts with 1 due to the background channel)
corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)
mask, _ = pad_divide_by(mask, 16)
if need_segment:
# merge predicted mask with the incomplete input mask
pred_prob_no_bg = pred_prob_with_bg[1:]
# use the mutual exclusivity of segmentation
if idx_mask:
pred_prob_no_bg[:, mask > 0] = 0
else:
pred_prob_no_bg[:, mask.max(0).values > 0.5] = 0
new_masks = []
for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
if idx_mask:
this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
else:
this_mask = mask[tmp_id]
if tmp_id > pred_prob_no_bg.shape[0]:
new_masks.append(this_mask.unsqueeze(0))
else:
# +1 for padding the background channel
pred_prob_no_bg[tmp_id - 1] = this_mask
# new_masks are always in the order of tmp_id
mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
elif idx_mask:
# simply convert cls to one-hot representation
if len(objects) == 0:
if delete_buffer:
self.image_feature_store.delete(self.curr_ti)
log.warn('Trying to insert an empty mask as memory!')
return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
device=key.device,
dtype=key.dtype)
mask = torch.stack(
[mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
dim=0)
if matting:
mask = mask.unsqueeze(0).float() / 255.
pred_prob_with_bg = torch.cat([1-mask, mask], 0)
else:
pred_prob_with_bg = aggregate(mask, dim=0)
pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)
self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
if self.flip_aug:
self.last_mask = torch.cat(
[self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
self.last_pix_feat = pix_feat
# save as memory if needed
if is_mem_frame or force_permanent:
# clear the memory for given mask and add the first predicted mask
if first_frame_pred:
self.clear_temp_mem()
self._add_memory(image,
pix_feat,
self.last_mask,
key,
shrinkage,
selection,
force_permanent=force_permanent,
is_deep_update=True)
else: # compute self.last_msk_value for non-memory frame
msk_value, _, _, _ = self.network.encode_mask(
image,
pix_feat,
self.memory.get_sensory(self.object_manager.all_obj_ids),
self.last_mask,
deep_update=False,
chunk_size=self.chunk_size,
need_weights=self.save_aux)
self.last_msk_value = msk_value
if delete_buffer:
self.image_feature_store.delete(self.curr_ti)
output_prob = unpad(pred_prob_with_bg, self.pad)
if resize_needed:
# restore output to the original size
output_prob = F.interpolate(output_prob.unsqueeze(0),
size=(h, w),
mode='bilinear',
align_corners=False)[0]
return output_prob
def delete_objects(self, objects: List[int]) -> None:
"""
Delete the given objects from the memory.
"""
self.object_manager.delete_objects(objects)
self.memory.purge_except(self.object_manager.all_obj_ids)
def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
if matting:
new_mask = output_prob[1:].squeeze(0)
else:
mask = torch.argmax(output_prob, dim=0)
# index in tensor != object id -- remap the ids here
new_mask = torch.zeros_like(mask)
for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
new_mask[mask == tmp_id] = obj.id
return new_mask
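# A rough usage sketch based on the step() docstring and the first_frame_pred logic above.
# The actual driver is matanyone_wrapper.matanyone (not shown in this file), so the names,
# shapes and exact call pattern below are assumptions:
#
# processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
# # frames: list of 3*H*W float tensors; first_mask: H*W tensor with values in [0, 255]
# for ti, frame in enumerate(frames):
#     if ti == 0:
#         processor.step(frame, first_mask, objects=[1])        # memorize the user-given mask
#         prob = processor.step(frame, first_frame_pred=True)   # re-predict frame 0 from memory
#     else:
#         prob = processor.step(frame)                          # propagate using memory
#     alpha = processor.output_prob_to_mask(prob, matting=True) # soft alpha matte, H*W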

@@ -0,0 +1,348 @@
from typing import Dict, List, Optional, Literal
from collections import defaultdict
import torch
def _add_last_dim(dictionary, key, new_value, prepend=False):
    # append/prepend a new value to the last dimension of a tensor in a dictionary
    # if the key does not exist, put the new value in
    # append by default, prepend when requested (used for the permanent memory part)
    if key in dictionary:
        if prepend:
            dictionary[key] = torch.cat([new_value, dictionary[key]], -1)
        else:
            dictionary[key] = torch.cat([dictionary[key], new_value], -1)
    else:
        dictionary[key] = new_value
class KeyValueMemoryStore:
"""
Key/value pair storage, e.g., for the working and long-term memory.
"""
def __init__(self, save_selection: bool = False, save_usage: bool = False):
"""
We store keys and values of objects that first appear in the same frame in a bucket.
Each bucket contains a set of object ids.
Each bucket is associated with a single key tensor
and a dictionary of value tensors indexed by object id.
The keys and values are stored as the concatenation of a permanent part and a temporary part.
"""
self.save_selection = save_selection
self.save_usage = save_usage
self.global_bucket_id = 0 # does not reduce even if buckets are removed
self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
self.v: Dict[int, torch.Tensor] = {} # indexed by object id
# indexed by bucket id; the end point of permanent memory
self.perm_end_pt: Dict[int, int] = defaultdict(int)
# shrinkage and selection are just like the keys
self.s = {}
if self.save_selection:
self.e = {} # does not contain the permanent memory part
# usage
if self.save_usage:
self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
def add(self,
key: torch.Tensor,
values: Dict[int, torch.Tensor],
shrinkage: torch.Tensor,
selection: torch.Tensor,
supposed_bucket_id: int = -1,
as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
"""
key: (1/2)*C*N
values: dict of values ((1/2)*C*N), object ids are used as keys
shrinkage: (1/2)*1*N
selection: (1/2)*C*N
supposed_bucket_id: used to sync the bucket id between working and long-term memory
if provided, the input should all be in a single bucket indexed by this id
as_permanent: whether to store the input as permanent memory
'no': don't
'first': only store it as permanent memory if the bucket is empty
'all': always store it as permanent memory
"""
bs = key.shape[0]
ne = key.shape[-1]
assert len(key.shape) == 3
assert len(shrinkage.shape) == 3
assert not self.save_selection or len(selection.shape) == 3
assert as_permanent in ['no', 'first', 'all']
# add the value and create new buckets if necessary
if supposed_bucket_id >= 0:
enabled_buckets = [supposed_bucket_id]
bucket_exist = supposed_bucket_id in self.buckets
for obj, value in values.items():
if bucket_exist:
assert obj in self.v
assert obj in self.buckets[supposed_bucket_id]
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
else:
assert obj not in self.v
self.v[obj] = value
self.buckets[supposed_bucket_id] = list(values.keys())
else:
new_bucket_id = None
enabled_buckets = set()
for obj, value in values.items():
assert len(value.shape) == 3
if obj in self.v:
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
bucket_used = [
bucket_id for bucket_id, object_ids in self.buckets.items()
if obj in object_ids
]
assert len(bucket_used) == 1 # each object should only be in one bucket
enabled_buckets.add(bucket_used[0])
else:
self.v[obj] = value
if new_bucket_id is None:
# create new bucket
new_bucket_id = self.global_bucket_id
self.global_bucket_id += 1
self.buckets[new_bucket_id] = []
# put the new object into the corresponding bucket
self.buckets[new_bucket_id].append(obj)
enabled_buckets.add(new_bucket_id)
# increment the permanent size if necessary
add_as_permanent = {} # indexed by bucket id
for bucket_id in enabled_buckets:
add_as_permanent[bucket_id] = False
if as_permanent == 'all':
self.perm_end_pt[bucket_id] += ne
add_as_permanent[bucket_id] = True
elif as_permanent == 'first':
if self.perm_end_pt[bucket_id] == 0:
self.perm_end_pt[bucket_id] = ne
add_as_permanent[bucket_id] = True
# create new counters for usage if necessary
if self.save_usage and as_permanent != 'all':
new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
# add the key to every bucket
for bucket_id in self.buckets:
if bucket_id not in enabled_buckets:
# if we are not adding new values to a bucket, we should skip it
continue
_add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
_add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
if not add_as_permanent[bucket_id]:
if self.save_selection:
_add_last_dim(self.e, bucket_id, selection)
if self.save_usage:
_add_last_dim(self.use_cnt, bucket_id, new_count)
_add_last_dim(self.life_cnt, bucket_id, new_life)
def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
# increase all life count by 1
# increase use of indexed elements
if not self.save_usage:
return
usage = usage[:, self.perm_end_pt[bucket_id]:]
if usage.shape[-1] == 0:
# if there is no temporary memory, we don't need to update
return
self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
self.life_cnt[bucket_id] += 1
def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
# keep only the temporary elements *outside* of this range (with some boundary conditions)
# the permanent elements are ignored in this computation
# i.e., concat (a[:start], a[end:])
# bucket with size <= min_size are not modified
assert start >= 0
assert end <= 0
object_ids = self.buckets[bucket_id]
bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
if bucket_num_elements <= min_size:
return
if end == 0:
# negative 0 would not work as the end index!
# effectively make the second part an empty slice
end = self.k[bucket_id].shape[-1] + 1
p_size = self.perm_end_pt[bucket_id]
start = start + p_size
k = self.k[bucket_id]
s = self.s[bucket_id]
if self.save_selection:
e = self.e[bucket_id]
if self.save_usage:
use_cnt = self.use_cnt[bucket_id]
life_cnt = self.life_cnt[bucket_id]
self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
if self.save_selection:
self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
if self.save_usage:
self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
-1)
for obj_id in object_ids:
v = self.v[obj_id]
self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
self.sieve_by_range(bucket_id, 0, -max_len, max_len)
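# Worked example (hypothetical sizes): with 12 temporary elements in a bucket and max_len=10,
# remove_old_memory() calls sieve_by_range(bucket_id, 0, -10, 10), which keeps only the last
# 10 temporary elements; permanent elements are never touched, and with 10 or fewer temporary
# elements the min_size guard makes this a no-op.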
def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
# for long-term memory only
object_ids = self.buckets[bucket_id]
assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
# normalize with life duration
usage = self.get_usage(bucket_id)
bs = usage.shape[0]
survivals = []
for bi in range(bs):
_, survived = torch.topk(usage[bi], k=max_size)
survivals.append(survived.flatten())
assert survived.shape[-1] == survivals[0].shape[-1]
self.k[bucket_id] = torch.stack(
[self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
self.s[bucket_id] = torch.stack(
[self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
if self.save_selection:
# Long-term memory does not store selection so this should not be needed
self.e[bucket_id] = torch.stack(
[self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
for obj_id in object_ids:
self.v[obj_id] = torch.stack(
[self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
self.use_cnt[bucket_id] = torch.stack(
[self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
self.life_cnt[bucket_id] = torch.stack(
[self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
def get_usage(self, bucket_id: int) -> torch.Tensor:
# return normalized usage
if not self.save_usage:
raise RuntimeError('I did not count usage!')
else:
usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
return usage
def get_all_sliced(
self, bucket_id: int, start: int, end: int
) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
# return k, sk, ek, value, normalized usage in order, sliced by start and end
# this only queries the temporary memory
assert start >= 0
assert end <= 0
p_size = self.perm_end_pt[bucket_id]
start = start + p_size
if end == 0:
# negative 0 would not work as the end index!
k = self.k[bucket_id][:, :, start:]
sk = self.s[bucket_id][:, :, start:]
ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
else:
k = self.k[bucket_id][:, :, start:end]
sk = self.s[bucket_id][:, :, start:end]
ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
return k, sk, ek, value, usage
def purge_except(self, obj_keep_idx: List[int]):
# purge certain objects from the memory except the one listed
obj_keep_idx = set(obj_keep_idx)
# remove objects that are not in the keep list from the buckets
buckets_to_remove = []
for bucket_id, object_ids in self.buckets.items():
self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
if len(self.buckets[bucket_id]) == 0:
buckets_to_remove.append(bucket_id)
# remove object values that are not in the keep list
self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
# remove buckets that are empty
for bucket_id in buckets_to_remove:
del self.buckets[bucket_id]
del self.k[bucket_id]
del self.s[bucket_id]
if self.save_selection:
del self.e[bucket_id]
if self.save_usage:
del self.use_cnt[bucket_id]
del self.life_cnt[bucket_id]
def clear_non_permanent_memory(self):
# clear all non-permanent memory
for bucket_id in self.buckets:
self.sieve_by_range(bucket_id, 0, 0, 0)
def get_v_size(self, obj_id: int) -> int:
return self.v[obj_id].shape[-1]
def size(self, bucket_id: int) -> int:
if bucket_id not in self.k:
return 0
else:
return self.k[bucket_id].shape[-1]
def perm_size(self, bucket_id: int) -> int:
return self.perm_end_pt[bucket_id]
def non_perm_size(self, bucket_id: int) -> int:
return self.size(bucket_id) - self.perm_size(bucket_id)
def engaged(self, bucket_id: Optional[int] = None) -> bool:
if bucket_id is None:
return len(self.buckets) > 0
else:
return bucket_id in self.buckets
@property
def num_objects(self) -> int:
return len(self.v)
@property
def key(self) -> Dict[int, torch.Tensor]:
return self.k
@property
def value(self) -> Dict[int, torch.Tensor]:
return self.v
@property
def shrinkage(self) -> Dict[int, torch.Tensor]:
return self.s
@property
def selection(self) -> Dict[int, torch.Tensor]:
return self.e
def __contains__(self, key):
return key in self.v

View File

@@ -0,0 +1,453 @@
import logging
from omegaconf import DictConfig
from typing import List, Dict
import torch
from .object_manager import ObjectManager
from .kv_memory_store import KeyValueMemoryStore
from ..model.matanyone import MatAnyone
from ..model.utils.memory_utils import get_similarity, do_softmax
log = logging.getLogger()
class MemoryManager:
"""
Manages all three memory stores and the transition between working/long-term memory
"""
def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
self.object_manager = object_manager
self.sensory_dim = cfg.model.sensory_dim
self.top_k = cfg.top_k
self.chunk_size = cfg.chunk_size
self.save_aux = cfg.save_aux
self.use_long_term = cfg.use_long_term
self.count_long_term_usage = cfg.long_term.count_usage
# subtract 1 because the first-frame is now counted as "permanent memory"
# and is not counted towards max_mem_frames
# but we want to keep the hyperparameters consistent as before for the same behavior
if self.use_long_term:
self.max_mem_frames = cfg.long_term.max_mem_frames - 1
self.min_mem_frames = cfg.long_term.min_mem_frames - 1
self.num_prototypes = cfg.long_term.num_prototypes
self.max_long_tokens = cfg.long_term.max_num_tokens
self.buffer_tokens = cfg.long_term.buffer_tokens
else:
self.max_mem_frames = cfg.max_mem_frames - 1
# dimensions will be inferred from input later
self.CK = self.CV = None
self.H = self.W = None
# The sensory memory is stored as a dictionary indexed by object ids
# each of shape bs * C^h * H * W
self.sensory = {}
# a dictionary indexed by object ids, each of shape bs * T * Q * C
self.obj_v = {}
self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
save_usage=self.use_long_term)
if self.use_long_term:
self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)
self.config_stale = True
self.engaged = False
def update_config(self, cfg: DictConfig) -> None:
self.config_stale = True
self.top_k = cfg['top_k']
assert self.use_long_term == cfg.use_long_term, 'cannot update this'
assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'
self.use_long_term = cfg.use_long_term
self.count_long_term_usage = cfg.long_term.count_usage
if self.use_long_term:
self.max_mem_frames = cfg.long_term.max_mem_frames - 1
self.min_mem_frames = cfg.long_term.min_mem_frames - 1
self.num_prototypes = cfg.long_term.num_prototypes
self.max_long_tokens = cfg.long_term.max_num_tokens
self.buffer_tokens = cfg.long_term.buffer_tokens
else:
self.max_mem_frames = cfg.max_mem_frames - 1
def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
# affinity: bs*N*HW
# v: bs*C*N or bs*num_objects*C*N
# returns bs*C*HW or bs*num_objects*C*HW
if len(v.shape) == 3:
# single object
if uncert_mask is not None:
return v @ affinity * uncert_mask
else:
return v @ affinity
else:
bs, num_objects, C, N = v.shape
v = v.view(bs, num_objects * C, N)
out = v @ affinity
if uncert_mask is not None:
uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
out = out * uncert_mask
return out.view(bs, num_objects, C, -1)
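# Shape sketch for the readout above (illustrative numbers only): with bs=1, num_objects=2,
# C=256, N=3120 memory tokens and HW=1560 query pixels, v.view(1, 2*256, 3120) matmul
# affinity of shape (1, 3120, 1560) gives (1, 512, 1560), which is reshaped back to
# (1, 2, 256, 1560) = bs*num_objects*C*HW.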
def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
# -1 because the mask does not contain the background channel
return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]
def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)
def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)
def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
# All the values that the object ids refer to should have the same shape
value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
if self.use_long_term and obj_ids[0] in self.long_mem.value:
lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
value = torch.cat([lt_value, value], dim=-1)
return value
def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]:
"""
Read from all memory stores and returns a single memory readout tensor for each object
pix_feat: (1/2) x C x H x W
query_key: (1/2) x C^k x H x W
selection: (1/2) x C^k x H x W
last_mask: (1/2) x num_objects x H x W (at stride 16)
return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
"""
h, w = pix_feat.shape[-2:]
bs = pix_feat.shape[0]
assert last_mask.shape[0] == bs
"""
Compute affinity and perform readout
"""
all_readout_mem = {}
buckets = self.work_mem.buckets
for bucket_id, bucket in buckets.items():
if self.chunk_size < 1:
object_chunks = [bucket]
else:
object_chunks = [
bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
]
for objects in object_chunks:
this_sensory = self._get_sensory_by_ids(objects)
this_last_mask = self._get_mask_by_ids(last_mask, objects)
this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
this_last_mask)
this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
for i, obj in enumerate(objects):
all_readout_mem[obj] = readout_memory[:, i]
if self.save_aux:
aux_output = {
# 'sensory': this_sensory,
# 'pixel_readout': pixel_readout,
'q_logits': aux_features['logits'] if aux_features else None,
# 'q_weights': aux_features['q_weights'] if aux_features else None,
# 'p_weights': aux_features['p_weights'] if aux_features else None,
# 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
}
self.aux = aux_output
return all_readout_mem
def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None,
last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
"""
Read from all memory stores and returns a single memory readout tensor for each object
pix_feat: (1/2) x C x H x W
query_key: (1/2) x C^k x H x W
selection: (1/2) x C^k x H x W
last_mask: (1/2) x num_objects x H x W (at stride 16)
return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
"""
h, w = pix_feat.shape[-2:]
bs = pix_feat.shape[0]
assert query_key.shape[0] == bs
assert selection.shape[0] == bs
assert last_mask.shape[0] == bs
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
query_key = query_key.flatten(start_dim=2) # bs*C^k*HW
selection = selection.flatten(start_dim=2) # bs*C^k*HW
"""
Compute affinity and perform readout
"""
all_readout_mem = {}
buckets = self.work_mem.buckets
for bucket_id, bucket in buckets.items():
if self.use_long_term and self.long_mem.engaged(bucket_id):
# Use long-term memory
long_mem_size = self.long_mem.size(bucket_id)
memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
-1)
shrinkage = torch.cat(
[self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)
similarity = get_similarity(memory_key, shrinkage, query_key, selection)
affinity, usage = do_softmax(similarity,
top_k=self.top_k,
inplace=True,
return_usage=True)
"""
Record memory usage for working and long-term memory
"""
# ignore the index return for long-term memory
work_usage = usage[:, long_mem_size:]
self.work_mem.update_bucket_usage(bucket_id, work_usage)
if self.count_long_term_usage:
# ignore the index return for working memory
long_usage = usage[:, :long_mem_size]
self.long_mem.update_bucket_usage(bucket_id, long_usage)
else:
# no long-term memory
memory_key = self.work_mem.key[bucket_id]
shrinkage = self.work_mem.shrinkage[bucket_id]
similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask)
if self.use_long_term:
affinity, usage = do_softmax(similarity,
top_k=self.top_k,
inplace=True,
return_usage=True)
self.work_mem.update_bucket_usage(bucket_id, usage)
else:
affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)
if self.chunk_size < 1:
object_chunks = [bucket]
else:
object_chunks = [
bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
]
for objects in object_chunks:
this_sensory = self._get_sensory_by_ids(objects)
this_last_mask = self._get_mask_by_ids(last_mask, objects)
this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
visual_readout = self._readout(affinity,
this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w)
uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0])
if uncert_output is not None:
uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob)
pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
this_last_mask)
this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
for i, obj in enumerate(objects):
all_readout_mem[obj] = readout_memory[:, i]
if self.save_aux:
aux_output = {
# 'sensory': this_sensory,
# 'pixel_readout': pixel_readout,
'q_logits': aux_features['logits'] if aux_features else None,
# 'q_weights': aux_features['q_weights'] if aux_features else None,
# 'p_weights': aux_features['p_weights'] if aux_features else None,
# 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
}
self.aux = aux_output
return all_readout_mem
def add_memory(self,
key: torch.Tensor,
shrinkage: torch.Tensor,
msk_value: torch.Tensor,
obj_value: torch.Tensor,
objects: List[int],
selection: torch.Tensor = None,
*,
as_permanent: bool = False) -> None:
# key: (1/2)*C*H*W
# msk_value: (1/2)*num_objects*C*H*W
# obj_value: (1/2)*num_objects*Q*C
# objects contains a list of object ids corresponding to the objects in msk_value/obj_value
bs = key.shape[0]
assert shrinkage.shape[0] == bs
assert msk_value.shape[0] == bs
assert obj_value.shape[0] == bs
self.engaged = True
if self.H is None or self.config_stale:
self.config_stale = False
self.H, self.W = msk_value.shape[-2:]
self.HW = self.H * self.W
# convert from num. frames to num. tokens
self.max_work_tokens = self.max_mem_frames * self.HW
if self.use_long_term:
self.min_work_tokens = self.min_mem_frames * self.HW
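# Worked example for the token budget above (hypothetical sizes): for a 480x832 input,
# the stride-16 memory features are 30x52, so self.HW = 1560; with cfg.max_mem_frames = 5
# the budget is max_work_tokens = (5 - 1) * 1560 = 6240 non-permanent tokens, because the
# first frame lives in permanent memory and is not counted.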
# key: bs*C*N
# value: bs*num_objects*C*N
key = key.flatten(start_dim=2)
shrinkage = shrinkage.flatten(start_dim=2)
self.CK = key.shape[1]
msk_value = msk_value.flatten(start_dim=3)
self.CV = msk_value.shape[2]
if selection is not None:
# not used in non-long-term mode
selection = selection.flatten(start_dim=2)
# insert object values into object memory
for obj_id, obj in enumerate(objects):
if obj in self.obj_v:
"""streaming average
each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
first embed_dim keeps track of the sum of embeddings
the last dim keeps the total count
averaging is done inside the object transformer
incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
"""
last_acc = self.obj_v[obj][:, :, -1]
new_acc = last_acc + obj_value[:, obj_id, :, -1]
self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
obj_value[:, obj_id, :, :-1])
self.obj_v[obj][:, :, -1] = new_acc
else:
self.obj_v[obj] = obj_value[:, obj_id]
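# Sketch of the streaming accumulator above (toy numbers, assumed layout): if a stored
# summary row is [sum_embed..., count] = [4.0, 6.0, 2.0] and the incoming obj_value row
# is [1.0, 2.0, 1.0], the stored row becomes [5.0, 8.0, 3.0]; dividing the summed
# embedding by the count (done later in the object transformer) recovers the running mean.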
# convert mask value tensor into a dict for insertion
msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
self.work_mem.add(key,
msk_values,
shrinkage,
selection=selection,
as_permanent=as_permanent)
for bucket_id in self.work_mem.buckets.keys():
# long-term memory cleanup
if self.use_long_term:
# Do memory compressed if needed
if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
# Remove obsolete features if needed
if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
self.num_prototypes):
self.long_mem.remove_obsolete_features(
bucket_id,
self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
self.compress_features(bucket_id)
else:
# FIFO
self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)
def purge_except(self, obj_keep_idx: List[int]) -> None:
# purge all objects from memory except the ones listed
self.work_mem.purge_except(obj_keep_idx)
if self.use_long_term and self.long_mem.engaged():
self.long_mem.purge_except(obj_keep_idx)
self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}
if not self.work_mem.engaged():
# everything is removed!
self.engaged = False
def compress_features(self, bucket_id: int) -> None:
# perform memory consolidation
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
*self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))
# remove consolidated working memory
self.work_mem.sieve_by_range(bucket_id,
0,
-self.min_work_tokens,
min_size=self.min_work_tokens)
# add to long-term memory
self.long_mem.add(prototype_key,
prototype_value,
prototype_shrinkage,
selection=None,
supposed_bucket_id=bucket_id)
def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
# find the indices with max usage
bs = candidate_key.shape[0]
assert bs in [1, 2]
prototype_key = []
prototype_selection = []
for bi in range(bs):
_, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
prototype_indices = max_usage_indices.flatten()
prototype_key.append(candidate_key[bi, :, prototype_indices])
prototype_selection.append(candidate_selection[bi, :, prototype_indices])
prototype_key = torch.stack(prototype_key, dim=0)
prototype_selection = torch.stack(prototype_selection, dim=0)
"""
Potentiation step
"""
similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
prototype_selection)
affinity = do_softmax(similarity)
# readout the values
prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}
# readout the shrinkage term
prototype_shrinkage = self._readout(affinity, candidate_shrinkage)
return prototype_key, prototype_value, prototype_shrinkage
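# Consolidation sketch (assumed sizes): from e.g. 6240 candidate tokens per batch item,
# the num_prototypes (say 128) most-used positions are kept as prototype keys, and the
# softmax affinity between all candidates and those prototypes pools the candidate values
# and shrinkage into 128 long-term tokens per object.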
def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
for obj in ids:
if obj not in self.sensory:
# also initializes the sensory memory
bs, _, h, w = sample_key.shape
self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
device=sample_key.device)
def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
# sensory: 1*num_objects*C*H*W
for obj_id, obj in enumerate(ids):
self.sensory[obj] = sensory[:, obj_id]
def get_sensory(self, ids: List[int]):
# returns (1/2)*num_objects*C*H*W
return self._get_sensory_by_ids(ids)
def clear_non_permanent_memory(self):
self.work_mem.clear_non_permanent_memory()
if self.use_long_term:
self.long_mem.clear_non_permanent_memory()
def clear_sensory_memory(self):
self.sensory = {}
def clear_work_mem(self):
self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
save_usage=self.use_long_term)
def clear_obj_mem(self):
self.obj_v = {}

View File

@@ -0,0 +1,24 @@
class ObjectInfo:
"""
Store meta information for an object
"""
def __init__(self, id: int):
self.id = id
self.poke_count = 0 # count number of detections missed
def poke(self) -> None:
self.poke_count += 1
def unpoke(self) -> None:
self.poke_count = 0
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
if type(other) == int:
return self.id == other
return self.id == other.id
def __repr__(self):
return f'(ID: {self.id})'

View File

@@ -0,0 +1,149 @@
from typing import Union, List, Dict
import torch
from .object_info import ObjectInfo
class ObjectManager:
"""
Object IDs are immutable. The same ID always represent the same object.
Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
Temporary IDs start from 1.
"""
def __init__(self):
self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
self.all_historical_object_ids: List[int] = []
def _recompute_obj_id_to_obj_mapping(self) -> None:
self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
def add_new_objects(
self, objects: Union[List[ObjectInfo], ObjectInfo,
List[int]]) -> (List[int], List[int]):
if not isinstance(objects, list):
objects = [objects]
corresponding_tmp_ids = []
corresponding_obj_ids = []
for obj in objects:
if isinstance(obj, int):
obj = ObjectInfo(id=obj)
if obj in self.obj_to_tmp_id:
# old object
corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
corresponding_obj_ids.append(obj.id)
else:
# new object
new_obj = ObjectInfo(id=obj.id)
# new object
new_tmp_id = len(self.obj_to_tmp_id) + 1
self.obj_to_tmp_id[new_obj] = new_tmp_id
self.tmp_id_to_obj[new_tmp_id] = new_obj
self.all_historical_object_ids.append(new_obj.id)
corresponding_tmp_ids.append(new_tmp_id)
corresponding_obj_ids.append(new_obj.id)
self._recompute_obj_id_to_obj_mapping()
assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
return corresponding_tmp_ids, corresponding_obj_ids
def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
# delete an object or a list of objects
# re-sort the tmp ids
if isinstance(obj_ids_to_remove, int):
obj_ids_to_remove = [obj_ids_to_remove]
new_tmp_id = 1
total_num_id = len(self.obj_to_tmp_id)
local_obj_to_tmp_id = {}
local_tmp_to_obj_id = {}
for tmp_iter in range(1, total_num_id + 1):
obj = self.tmp_id_to_obj[tmp_iter]
if obj.id not in obj_ids_to_remove:
local_obj_to_tmp_id[obj] = new_tmp_id
local_tmp_to_obj_id[new_tmp_id] = obj
new_tmp_id += 1
self.obj_to_tmp_id = local_obj_to_tmp_id
self.tmp_id_to_obj = local_tmp_to_obj_id
self._recompute_obj_id_to_obj_mapping()
def purge_inactive_objects(self,
max_missed_detection_count: int) -> (bool, List[int], List[int]):
# remove tmp ids of objects that are removed
obj_id_to_be_deleted = []
tmp_id_to_be_deleted = []
tmp_id_to_keep = []
obj_id_to_keep = []
for obj in self.obj_to_tmp_id:
if obj.poke_count > max_missed_detection_count:
obj_id_to_be_deleted.append(obj.id)
tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
else:
tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
obj_id_to_keep.append(obj.id)
purge_activated = len(obj_id_to_be_deleted) > 0
if purge_activated:
self.delete_objects(obj_id_to_be_deleted)
return purge_activated, tmp_id_to_keep, obj_id_to_keep
def tmp_to_obj_cls(self, mask) -> torch.Tensor:
# remap tmp id cls representation to the true object id representation
new_mask = torch.zeros_like(mask)
for tmp_id, obj in self.tmp_id_to_obj.items():
new_mask[mask == tmp_id] = obj.id
return new_mask
def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
# returns the mapping in a dict format for saving it with pickle
return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}
def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
# turns a dict indexed by obj id into a tensor, ordered by tmp IDs
output = []
for _, obj in self.tmp_id_to_obj.items():
if obj.id not in obj_dict:
raise NotImplementedError
output.append(obj_dict[obj.id])
output = torch.stack(output, dim=dim)
return output
def make_one_hot(self, cls_mask) -> torch.Tensor:
output = []
for _, obj in self.tmp_id_to_obj.items():
output.append(cls_mask == obj.id)
if len(output) == 0:
output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
else:
output = torch.stack(output, dim=0)
return output
@property
def all_obj_ids(self) -> List[int]:
return [k.id for k in self.obj_to_tmp_id]
@property
def num_obj(self) -> int:
return len(self.obj_to_tmp_id)
def has_all(self, objects: List[int]) -> bool:
for obj in objects:
if obj not in self.obj_to_tmp_id:
return False
return True
def find_object_by_id(self, obj_id) -> ObjectInfo:
return self.obj_id_to_obj[obj_id]
def find_tmp_by_id(self, obj_id) -> int:
return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
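# Illustrative usage of the temporary-id bookkeeping described in the class docstring;
# the object ids below are hypothetical:
#   om = ObjectManager()
#   om.add_new_objects([5, 9])   # -> tmp_ids [1, 2], obj_ids [5, 9]
#   om.delete_objects(5)         # object 9 is re-assigned tmp_id 1
#   om.find_tmp_by_id(9)         # -> 1; om.all_obj_ids -> [9]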

View File

@@ -0,0 +1,30 @@
import logging
from omegaconf import DictConfig
log = logging.getLogger()
def get_dataset_cfg(cfg: DictConfig):
dataset_name = cfg.dataset
data_cfg = cfg.datasets[dataset_name]
potential_overrides = [
'image_directory',
'mask_directory',
'json_directory',
'size',
'save_all',
'use_all_masks',
'use_long_term',
'mem_every',
]
for override in potential_overrides:
if cfg[override] is not None:
log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
data_cfg[override] = cfg[override]
# escalate all potential overrides to the top-level config
if override in data_cfg:
cfg[override] = data_cfg[override]
return data_cfg
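# Sketch of the override flow above (assumed config values): with cfg.dataset = 'd17-val',
# cfg.size = 720 and cfg.datasets['d17-val'].size = 480, the loop logs the override,
# sets data_cfg.size = 720, and then copies every listed key back to the top-level cfg
# so downstream code can read cfg.size directly.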

View File

@@ -0,0 +1,93 @@
"""
For computing auxiliary outputs for auxiliary losses
"""
from typing import Dict
from omegaconf import DictConfig
import torch
import torch.nn as nn
from .group_modules import GConv2d
from ...utils.tensor_utils import aggregate
class LinearPredictor(nn.Module):
def __init__(self, x_dim: int, pix_dim: int):
super().__init__()
self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# pixel_feat: B*pix_dim*H*W
# x: B*num_objects*x_dim*H*W
num_objects = x.shape[1]
x = self.projection(x)
pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
return logits
class DirectPredictor(nn.Module):
def __init__(self, x_dim: int):
super().__init__()
self.projection = GConv2d(x_dim, 1, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: B*num_objects*x_dim*H*W
logits = self.projection(x).squeeze(2)
return logits
class AuxComputer(nn.Module):
def __init__(self, cfg: DictConfig):
super().__init__()
use_sensory_aux = cfg.model.aux_loss.sensory.enabled
self.use_query_aux = cfg.model.aux_loss.query.enabled
self.use_sensory_aux = use_sensory_aux
sensory_dim = cfg.model.sensory_dim
embed_dim = cfg.model.embed_dim
if use_sensory_aux:
self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
prob = torch.sigmoid(logits)
if selector is not None:
prob = prob * selector
logits = aggregate(prob, dim=1)
return logits
def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
sensory = aux_input['sensory']
q_logits = aux_input['q_logits']
aux_output = {}
aux_output['attn_mask'] = aux_input['attn_mask']
if self.use_sensory_aux:
# B*num_objects*H*W
logits = self.sensory_aux(pix_feat, sensory)
aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
if self.use_query_aux:
# B*num_objects*num_levels*H*W
aux_output['q_logits'] = self._aggregate_with_selector(
torch.stack(q_logits, dim=2),
selector.unsqueeze(2) if selector is not None else None)
return aux_output
def compute_mask(self, aux_input: Dict[str, torch.Tensor],
selector: torch.Tensor) -> Dict[str, torch.Tensor]:
# sensory = aux_input['sensory']
q_logits = aux_input['q_logits']
aux_output = {}
# B*num_objects*num_levels*H*W
aux_output['q_logits'] = self._aggregate_with_selector(
torch.stack(q_logits, dim=2),
selector.unsqueeze(2) if selector is not None else None)
return aux_output

View File

@@ -0,0 +1,365 @@
"""
big_modules.py - This file stores higher-level network blocks.
x - usually denotes features that are shared between objects.
g - usually denotes features that are not shared between objects
with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
The trailing number of a variable usually denotes the stride
"""
from typing import Iterable
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from .group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
from .utils import resnet
from .modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
class UncertPred(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
self.bn2 = nn.BatchNorm2d(32)
self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
x = self.conv1x1_v2(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv3x3(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3x3_out(x)
return x
# override the default train() to freeze BN statistics
def train(self, mode=True):
self.training = False
for module in self.children():
module.train(False)
return self
class PixelEncoder(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
# use model_cfg.pretrained_resnet if it is set, otherwise default to True
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
if self.is_resnet:
if model_cfg.pixel_encoder.type == 'resnet18':
network = resnet.resnet18(pretrained=is_pretrained_resnet)
elif model_cfg.pixel_encoder.type == 'resnet50':
network = resnet.resnet50(pretrained=is_pretrained_resnet)
else:
raise NotImplementedError
self.conv1 = network.conv1
self.bn1 = network.bn1
self.relu = network.relu
self.maxpool = network.maxpool
self.res2 = network.layer1
self.layer2 = network.layer2
self.layer3 = network.layer3
else:
raise NotImplementedError
def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
f1 = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
f2 = x
x = self.maxpool(x)
f4 = self.res2(x)
f8 = self.layer2(f4)
f16 = self.layer3(f8)
return f16, f8, f4, f2, f1
# override the default train() to freeze BN statistics
def train(self, mode=True):
self.training = False
for module in self.children():
module.train(False)
return self
class KeyProjection(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
in_dim = model_cfg.pixel_encoder.ms_dims[0]
mid_dim = model_cfg.pixel_dim
key_dim = model_cfg.key_dim
self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
# shrinkage
self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
# selection
self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
nn.init.orthogonal_(self.key_proj.weight.data)
nn.init.zeros_(self.key_proj.bias.data)
def forward(self, x: torch.Tensor, *, need_s: bool,
need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
x = self.pix_feat_proj(x)
shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
return self.key_proj(x), shrinkage, selection
class MaskEncoder(nn.Module):
def __init__(self, model_cfg: DictConfig, single_object=False):
super().__init__()
pixel_dim = model_cfg.pixel_dim
value_dim = model_cfg.value_dim
sensory_dim = model_cfg.sensory_dim
final_dim = model_cfg.mask_encoder.final_dim
self.single_object = single_object
extra_dim = 1 if single_object else 2
# use model_cfg.pretrained_resnet if it is set, otherwise default to True
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
if model_cfg.mask_encoder.type == 'resnet18':
network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
elif model_cfg.mask_encoder.type == 'resnet50':
network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
else:
raise NotImplementedError
self.conv1 = network.conv1
self.bn1 = network.bn1
self.relu = network.relu
self.maxpool = network.maxpool
self.layer1 = network.layer1
self.layer2 = network.layer2
self.layer3 = network.layer3
self.distributor = MainToGroupDistributor()
self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
def forward(self,
image: torch.Tensor,
pix_feat: torch.Tensor,
sensory: torch.Tensor,
masks: torch.Tensor,
others: torch.Tensor,
*,
deep_update: bool = True,
chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
# ms_features are from the key encoder
# we only use the first one (lowest resolution), following XMem
if self.single_object:
g = masks.unsqueeze(2)
else:
g = torch.stack([masks, others], dim=2)
g = self.distributor(image, g)
batch_size, num_objects = g.shape[:2]
if chunk_size < 1 or chunk_size >= num_objects:
chunk_size = num_objects
fast_path = True
new_sensory = sensory
else:
if deep_update:
new_sensory = torch.empty_like(sensory)
else:
new_sensory = sensory
fast_path = False
# chunk-by-chunk inference
all_g = []
for i in range(0, num_objects, chunk_size):
if fast_path:
g_chunk = g
else:
g_chunk = g[:, i:i + chunk_size]
actual_chunk_size = g_chunk.shape[1]
g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
g_chunk = self.conv1(g_chunk)
g_chunk = self.bn1(g_chunk) # 1/2, 64
g_chunk = self.maxpool(g_chunk) # 1/4, 64
g_chunk = self.relu(g_chunk)
g_chunk = self.layer1(g_chunk) # 1/4
g_chunk = self.layer2(g_chunk) # 1/8
g_chunk = self.layer3(g_chunk) # 1/16
g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
g_chunk = self.fuser(pix_feat, g_chunk)
all_g.append(g_chunk)
if deep_update:
if fast_path:
new_sensory = self.sensory_update(g_chunk, sensory)
else:
new_sensory[:, i:i + chunk_size] = self.sensory_update(
g_chunk, sensory[:, i:i + chunk_size])
g = torch.cat(all_g, dim=1)
return g, new_sensory
# override the default train() to freeze BN statistics
def train(self, mode=True):
self.training = False
for module in self.children():
module.train(False)
return self
class PixelFeatureFuser(nn.Module):
def __init__(self, model_cfg: DictConfig, single_object=False):
super().__init__()
value_dim = model_cfg.value_dim
sensory_dim = model_cfg.sensory_dim
pixel_dim = model_cfg.pixel_dim
embed_dim = model_cfg.embed_dim
self.single_object = single_object
self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
if self.single_object:
self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
else:
self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
def forward(self,
pix_feat: torch.Tensor,
pixel_memory: torch.Tensor,
sensory_memory: torch.Tensor,
last_mask: torch.Tensor,
last_others: torch.Tensor,
*,
chunk_size: int = -1) -> torch.Tensor:
batch_size, num_objects = pixel_memory.shape[:2]
if self.single_object:
last_mask = last_mask.unsqueeze(2)
else:
last_mask = torch.stack([last_mask, last_others], dim=2)
if chunk_size < 1:
chunk_size = num_objects
# chunk-by-chunk inference
all_p16 = []
for i in range(0, num_objects, chunk_size):
sensory_readout = self.sensory_compress(
torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
p16 = self.fuser(pix_feat, p16)
all_p16.append(p16)
p16 = torch.cat(all_p16, dim=1)
return p16
class MaskDecoder(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
embed_dim = model_cfg.embed_dim
sensory_dim = model_cfg.sensory_dim
ms_image_dims = model_cfg.pixel_encoder.ms_dims
up_dims = model_cfg.mask_decoder.up_dims
assert embed_dim == up_dims[0]
self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
sensory_dim)
self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
# newly add for alpha matte
self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])
self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
def forward(self,
ms_image_feat: Iterable[torch.Tensor],
memory_readout: torch.Tensor,
sensory: torch.Tensor,
*,
chunk_size: int = -1,
update_sensory: bool = True,
seg_pass: bool = False,
last_mask=None,
sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):
batch_size, num_objects = memory_readout.shape[:2]
f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
if chunk_size < 1 or chunk_size >= num_objects:
chunk_size = num_objects
fast_path = True
new_sensory = sensory
else:
if update_sensory:
new_sensory = torch.empty_like(sensory)
else:
new_sensory = sensory
fast_path = False
# chunk-by-chunk inference
all_logits = []
for i in range(0, num_objects, chunk_size):
if fast_path:
p16 = memory_readout
else:
p16 = memory_readout[:, i:i + chunk_size]
actual_chunk_size = p16.shape[1]
p8 = self.up_16_8(p16, f8)
p4 = self.up_8_4(p8, f4)
p2 = self.up_4_2(p4, f2)
p1 = self.up_2_1(p2, f1)
with torch.amp.autocast("cuda"):
if seg_pass:
if last_mask is not None:
res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
if sigmoid_residual:
res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
logits = last_mask + res
else:
logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
else:
if last_mask is not None:
res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
if sigmoid_residual:
res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
logits = last_mask + res
else:
logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
## SensoryUpdater_fullscale
if update_sensory:
p1 = torch.cat(
[p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
if fast_path:
new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
else:
new_sensory[:,
i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1],
sensory[:,
i:i + chunk_size])
all_logits.append(logits)
logits = torch.cat(all_logits, dim=0)
logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
return new_sensory, logits

View File

@@ -0,0 +1,39 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class CAResBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
super().__init__()
self.residual = residual
self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
t = int((abs(math.log2(out_dim)) + 1) // 2)
k = t if t % 2 else t + 1
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
if self.residual:
if in_dim == out_dim:
self.downsample = nn.Identity()
else:
self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r = x
x = self.conv1(F.relu(x))
x = self.conv2(F.relu(x))
b, c = x.shape[:2]
w = self.pool(x).view(b, 1, c)
w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
if self.residual:
x = x * w + self.downsample(r)
else:
x = x * w
return x
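# Kernel-size sketch for the ECA-style channel attention above (worked arithmetic):
# with out_dim = 256, log2(256) = 8, t = (8 + 1) // 2 = 4, and since 4 is even the 1-D
# conv uses k = t + 1 = 5; with out_dim = 64, t = (6 + 1) // 2 = 3 is odd, so k = 3.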

View File

@@ -0,0 +1,126 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .channel_attn import CAResBlock
def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
align_corners: bool) -> torch.Tensor:
batch_size, num_objects = g.shape[:2]
g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
scale_factor=ratio,
mode=mode,
align_corners=align_corners)
g = g.view(batch_size, num_objects, *g.shape[1:])
return g
def upsample_groups(g: torch.Tensor,
ratio: float = 2,
mode: str = 'bilinear',
align_corners: bool = False) -> torch.Tensor:
return interpolate_groups(g, ratio, mode, align_corners)
def downsample_groups(g: torch.Tensor,
ratio: float = 1 / 2,
mode: str = 'area',
align_corners: bool = None) -> torch.Tensor:
return interpolate_groups(g, ratio, mode, align_corners)
class GConv2d(nn.Conv2d):
def forward(self, g: torch.Tensor) -> torch.Tensor:
batch_size, num_objects = g.shape[:2]
g = super().forward(g.flatten(start_dim=0, end_dim=1))
return g.view(batch_size, num_objects, *g.shape[1:])
class GroupResBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
if in_dim == out_dim:
self.downsample = nn.Identity()
else:
self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
def forward(self, g: torch.Tensor) -> torch.Tensor:
out_g = self.conv1(F.relu(g))
out_g = self.conv2(F.relu(out_g))
g = self.downsample(g)
return out_g + g
class MainToGroupDistributor(nn.Module):
def __init__(self,
x_transform: Optional[nn.Module] = None,
g_transform: Optional[nn.Module] = None,
method: str = 'cat',
reverse_order: bool = False):
super().__init__()
self.x_transform = x_transform
self.g_transform = g_transform
self.method = method
self.reverse_order = reverse_order
def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
num_objects = g.shape[1]
if self.x_transform is not None:
x = self.x_transform(x)
if self.g_transform is not None:
g = self.g_transform(g)
if not skip_expand:
x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
if self.method == 'cat':
if self.reverse_order:
g = torch.cat([g, x], 2)
else:
g = torch.cat([x, g], 2)
elif self.method == 'add':
g = x + g
elif self.method == 'mulcat':
g = torch.cat([x * g, g], dim=2)
elif self.method == 'muladd':
g = x * g + g
else:
raise NotImplementedError
return g
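# Distribution sketch (toy shapes): with x of shape (1, 64, 30, 52) and g of shape
# (1, 3, 64, 30, 52), method='cat' expands x over the 3 objects and returns
# (1, 3, 128, 30, 52); method='add' would instead return (1, 3, 64, 30, 52).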
class GroupFeatureFusionBlock(nn.Module):
def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
super().__init__()
x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
self.distributor = MainToGroupDistributor(x_transform=x_transform,
g_transform=g_transform,
method='add')
self.block1 = CAResBlock(out_dim, out_dim)
self.block2 = CAResBlock(out_dim, out_dim)
def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
batch_size, num_objects = g.shape[:2]
g = self.distributor(x, g)
g = g.flatten(start_dim=0, end_dim=1)
g = self.block1(g)
g = self.block2(g)
g = g.view(batch_size, num_objects, *g.shape[1:])
return g

View File

@@ -0,0 +1,333 @@
from typing import List, Dict, Iterable
import logging
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin
from .big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
from .aux_modules import AuxComputer
from .utils.memory_utils import get_affinity, readout
from .transformer.object_transformer import QueryTransformer
from .transformer.object_summarizer import ObjectSummarizer
from ...utils.tensor_utils import aggregate
log = logging.getLogger()
class MatAnyone(nn.Module,
PyTorchModelHubMixin,
library_name="matanyone",
repo_url="https://github.com/pq-yang/MatAnyone",
coders={
DictConfig: (
lambda x: OmegaConf.to_container(x),
lambda data: OmegaConf.create(data),
)
},
):
def __init__(self, cfg: DictConfig, *, single_object=False):
super().__init__()
self.cfg = cfg
model_cfg = cfg.model
self.ms_dims = model_cfg.pixel_encoder.ms_dims
self.key_dim = model_cfg.key_dim
self.value_dim = model_cfg.value_dim
self.sensory_dim = model_cfg.sensory_dim
self.pixel_dim = model_cfg.pixel_dim
self.embed_dim = model_cfg.embed_dim
self.single_object = single_object
log.info(f'Single object: {self.single_object}')
self.pixel_encoder = PixelEncoder(model_cfg)
self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
self.key_proj = KeyProjection(model_cfg)
self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
self.mask_decoder = MaskDecoder(model_cfg)
self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
self.object_transformer = QueryTransformer(model_cfg)
self.object_summarizer = ObjectSummarizer(model_cfg)
self.aux_computer = AuxComputer(cfg)
self.temp_sparity = UncertPred(model_cfg)
self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
# for each object, return the sum of masks of all other objects
if self.single_object:
return None
num_objects = masks.shape[1]
if num_objects >= 1:
others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
else:
others = torch.zeros_like(masks)
return others
def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
logits = self.temp_sparity(last_frame_feat=last_pix_feat,
cur_frame_feat=cur_pix_feat,
last_mask=last_mask,
mem_val_diff=mem_val_diff)
prob = torch.sigmoid(logits)
mask = (prob > 0) + 0
uncert_output = {"logits": logits,
"prob": prob,
"mask": mask}
return uncert_output
def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
image = (image - self.pixel_mean) / self.pixel_std
ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
def encode_mask(
self,
image: torch.Tensor,
ms_features: List[torch.Tensor],
sensory: torch.Tensor,
masks: torch.Tensor,
*,
deep_update: bool = True,
chunk_size: int = -1,
need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
image = (image - self.pixel_mean) / self.pixel_std
others = self._get_others(masks)
mask_value, new_sensory = self.mask_encoder(image,
ms_features,
sensory,
masks,
others,
deep_update=deep_update,
chunk_size=chunk_size)
object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
return mask_value, new_sensory, object_summaries, object_logits
def transform_key(self,
final_pix_feat: torch.Tensor,
*,
need_sk: bool = True,
need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
return key, shrinkage, selection
# Used in training only.
# This step is replaced by MemoryManager in test time
def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
sensory: torch.Tensor, last_mask: torch.Tensor,
selector: torch.Tensor, uncert_output=None, seg_pass=False,
last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]):
"""
query_key : B * CK * H * W
query_selection : B * CK * H * W
memory_key : B * CK * T * H * W
memory_shrinkage: B * 1 * T * H * W
msk_value : B * num_objects * CV * T * H * W
obj_memory : B * num_objects * T * num_summaries * C
pixel_feature : B * C * H * W
"""
batch_size, num_objects = msk_value.shape[:2]
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
# read using visual attention
with torch.cuda.amp.autocast(enabled=False):
affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
query_selection.float(), uncert_mask=uncert_mask)
msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
# B * (num_objects*CV) * H * W
pixel_readout = readout(affinity, msk_value, uncert_mask)
pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
*pixel_readout.shape[-2:])
uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1])
uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob)
pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
# read from query transformer
mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
aux_output = {
'sensory': sensory,
'q_logits': aux_features['logits'] if aux_features else None,
'attn_mask': aux_features['attn_mask'] if aux_features else None,
}
return mem_readout, aux_output, uncert_output
def read_first_frame_memory(self, pixel_readout,
obj_memory: torch.Tensor, pix_feat: torch.Tensor,
sensory: torch.Tensor, last_mask: torch.Tensor,
selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
"""
query_key : B * CK * H * W
query_selection : B * CK * H * W
memory_key : B * CK * T * H * W
memory_shrinkage: B * 1 * T * H * W
msk_value : B * num_objects * CV * T * H * W
obj_memory : B * num_objects * T * num_summaries * C
pixel_feature : B * C * H * W
"""
pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
# read from query transformer
mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
aux_output = {
'sensory': sensory,
'q_logits': aux_features['logits'] if aux_features else None,
'attn_mask': aux_features['attn_mask'] if aux_features else None,
}
return mem_readout, aux_output
def pixel_fusion(self,
pix_feat: torch.Tensor,
pixel: torch.Tensor,
sensory: torch.Tensor,
last_mask: torch.Tensor,
*,
chunk_size: int = -1) -> torch.Tensor:
last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
last_others = self._get_others(last_mask)
fused = self.pixel_fuser(pix_feat,
pixel,
sensory,
last_mask,
last_others,
chunk_size=chunk_size)
return fused
def readout_query(self,
pixel_readout,
obj_memory,
*,
selector=None,
need_weights=False,
seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
return self.object_transformer(pixel_readout,
obj_memory,
selector=selector,
need_weights=need_weights,
seg_pass=seg_pass)
def segment(self,
ms_image_feat: List[torch.Tensor],
memory_readout: torch.Tensor,
sensory: torch.Tensor,
*,
selector: bool = None,
chunk_size: int = -1,
update_sensory: bool = True,
seg_pass: bool = False,
clamp_mat: bool = True,
last_mask=None,
sigmoid_residual=False,
seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor):
"""
multi_scale_features is from the key encoder for skip-connection
memory_readout is from working/long-term memory
sensory is the sensory memory
last_mask is the mask from the last frame, supplementing sensory memory
selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
during training.
"""
#### use mat head for seg data
if seg_mat:
assert seg_pass
seg_pass = False
####
sensory, logits = self.mask_decoder(ms_image_feat,
memory_readout,
sensory,
chunk_size=chunk_size,
update_sensory=update_sensory,
seg_pass = seg_pass,
last_mask=last_mask,
sigmoid_residual=sigmoid_residual)
if seg_pass:
prob = torch.sigmoid(logits)
if selector is not None:
prob = prob * selector
# Softmax over all objects
logits = aggregate(prob, dim=1)
prob = F.softmax(logits, dim=1)
else:
if clamp_mat:
logits = logits.clamp(0.0, 1.0)
logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
prob = logits
return sensory, logits, prob
def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)
def forward(self, *args, **kwargs):
raise NotImplementedError
def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
if not self.single_object:
# Map single-object weight to multi-object weight (4->5 out channels in conv1)
for k in list(src_dict.keys()):
if k == 'mask_encoder.conv1.weight':
if src_dict[k].shape[1] == 4:
log.info(f'Converting {k} from single object to multiple objects.')
pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
if not init_as_zero_if_needed:
nn.init.orthogonal_(pads)
log.info(f'Randomly initialized padding for {k}.')
else:
log.info(f'Zero-initialized padding for {k}.')
src_dict[k] = torch.cat([src_dict[k], pads], 1)
elif k == 'pixel_fuser.sensory_compress.weight':
if src_dict[k].shape[1] == self.sensory_dim + 1:
log.info(f'Converting {k} from single object to multiple objects.')
pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
if not init_as_zero_if_needed:
nn.init.orthogonal_(pads)
log.info(f'Randomly initialized padding for {k}.')
else:
log.info(f'Zero-initialized padding for {k}.')
src_dict[k] = torch.cat([src_dict[k], pads], 1)
elif self.single_object:
"""
If the model is multiple-object and we are training in single-object,
we strip the last channel of conv1.
This is not supposed to happen in standard training except when users are trying to
finetune a trained model with single object datasets.
"""
if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object. '
'This is not supposed to happen in standard training.')
src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
for k in src_dict:
if k not in self.state_dict():
log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
for k in self.state_dict():
if k not in src_dict:
log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
self.load_state_dict(src_dict, strict=False)
@property
def device(self) -> torch.device:
return self.pixel_mean.device

View File

@@ -0,0 +1,149 @@
from typing import List, Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
from .group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
class UpsampleBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
super().__init__()
self.out_conv = ResBlock(in_dim, out_dim)
self.scale_factor = scale_factor
def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
g = F.interpolate(in_g,
scale_factor=self.scale_factor,
mode='bilinear')
g = self.out_conv(g)
g = g + skip_f
return g
class MaskUpsampleBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
super().__init__()
self.distributor = MainToGroupDistributor(method='add')
self.out_conv = GroupResBlock(in_dim, out_dim)
self.scale_factor = scale_factor
def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
g = upsample_groups(in_g, ratio=self.scale_factor)
g = self.distributor(skip_f, g)
g = self.out_conv(g)
return g
class DecoderFeatureProcessor(nn.Module):
def __init__(self, decoder_dims: List[int], out_dims: List[int]):
super().__init__()
self.transforms = nn.ModuleList([
nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
])
def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
return outputs
# @torch.jit.script
def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
# h: batch_size * num_objects * hidden_dim * h * w
# values: batch_size * num_objects * (hidden_dim*3) * h * w
dim = values.shape[2] // 3
forget_gate = torch.sigmoid(values[:, :, :dim])
update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
new_value = torch.tanh(values[:, :, dim * 2:])
new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
return new_h
class SensoryUpdater_fullscale(nn.Module):
# Used in the decoder, multi-scale feature + GRU
def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
super().__init__()
self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
self.g1_conv(downsample_groups(g[4], ratio=1/16))
with torch.amp.autocast("cuda"):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
class SensoryUpdater(nn.Module):
# Used in the decoder, multi-scale feature + GRU
def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
super().__init__()
self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
self.g4_conv(downsample_groups(g[2], ratio=1/4))
with torch.amp.autocast("cuda"):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
class SensoryDeepUpdater(nn.Module):
def __init__(self, f_dim: int, sensory_dim: int):
super().__init__()
self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda"):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
class ResBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
if in_dim == out_dim:
self.downsample = nn.Identity()
else:
self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
def forward(self, g: torch.Tensor) -> torch.Tensor:
out_g = self.conv1(F.relu(g))
out_g = self.conv2(F.relu(out_g))
g = self.downsample(g)
return out_g + g

View File

@@ -0,0 +1,89 @@
from typing import Optional
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from .positional_encoding import PositionalEncoding
# @torch.jit.script
def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
# value: B*num_objects*H*W*value_dim
# logits: B*num_objects*H*W*num_summaries
# masks: B*num_objects*H*W*num_summaries: 1 if allowed
weights = logits.sigmoid() * masks
# B*num_objects*num_summaries*value_dim
sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
# B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
# B*num_objects*num_summaries*value_dim
return sums, area
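# Einsum shape sketch (illustrative sizes): weights 'bkhwq' = (1, 2, 30, 52, 16) and
# value 'bkhwc' = (1, 2, 30, 52, 256) contract over h and w to give sums 'bkqc' =
# (1, 2, 16, 256); area flattens h, w and sums the weights to (1, 2, 16, 1), so each
# of the 16 summaries carries its pooled feature plus its total weight mass.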
class ObjectSummarizer(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
this_cfg = model_cfg.object_summarizer
self.value_dim = model_cfg.value_dim
self.embed_dim = this_cfg.embed_dim
self.num_summaries = this_cfg.num_summaries
self.add_pe = this_cfg.add_pe
self.pixel_pe_scale = model_cfg.pixel_pe_scale
self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
if self.add_pe:
self.pos_enc = PositionalEncoding(self.embed_dim,
scale=self.pixel_pe_scale,
temperature=self.pixel_pe_temperature)
self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
self.feature_pred = nn.Sequential(
nn.Linear(self.embed_dim, self.embed_dim),
nn.ReLU(inplace=True),
nn.Linear(self.embed_dim, self.embed_dim),
)
self.weights_pred = nn.Sequential(
nn.Linear(self.embed_dim, self.embed_dim),
nn.ReLU(inplace=True),
nn.Linear(self.embed_dim, self.num_summaries),
)
def forward(self,
masks: torch.Tensor,
value: torch.Tensor,
need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
# masks: B*num_objects*(H0)*(W0)
# value: B*num_objects*value_dim*H*W
# -> B*num_objects*H*W*value_dim
h, w = value.shape[-2:]
masks = F.interpolate(masks, size=(h, w), mode='area')
masks = masks.unsqueeze(-1)
inv_masks = 1 - masks
repeated_masks = torch.cat([
masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
],
dim=-1)
value = value.permute(0, 1, 3, 4, 2)
value = self.input_proj(value)
if self.add_pe:
pe = self.pos_enc(value)
value = value + pe
with torch.amp.autocast("cuda"):
value = value.float()
feature = self.feature_pred(value)
logits = self.weights_pred(value)
sums, area = _weighted_pooling(repeated_masks, feature, logits)
summaries = torch.cat([sums, area], dim=-1)
if need_weights:
return summaries, logits
else:
return summaries, None
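# --- Illustrative usage sketch (not part of the original file) ---
# Hypothetical config values, chosen only to make the tensor shapes concrete;
# because of the relative import above, run it from within the package.
if __name__ == '__main__':
    from omegaconf import OmegaConf
    cfg = OmegaConf.create({
        'value_dim': 256,
        'pixel_pe_scale': 32,
        'pixel_pe_temperature': 128,
        'object_summarizer': {'embed_dim': 256, 'num_summaries': 16, 'add_pe': True},
    })
    summarizer = ObjectSummarizer(cfg)
    masks = torch.rand(2, 3, 64, 64)         # B, num_objects, H0, W0
    value = torch.rand(2, 3, 256, 16, 16)    # B, num_objects, value_dim, H, W
    summaries, _ = summarizer(masks, value)
    print(summaries.shape)                   # B, num_objects, num_summaries, embed_dim + 1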

View File

@ -0,0 +1,206 @@
from typing import Dict, Optional
from omegaconf import DictConfig
import torch
import torch.nn as nn
from ..group_modules import GConv2d
from ....utils.tensor_utils import aggregate
from .positional_encoding import PositionalEncoding
from .transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
class QueryTransformerBlock(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
this_cfg = model_cfg.object_transformer
self.embed_dim = this_cfg.embed_dim
self.num_heads = this_cfg.num_heads
self.num_queries = this_cfg.num_queries
self.ff_dim = this_cfg.ff_dim
self.read_from_pixel = CrossAttention(self.embed_dim,
self.num_heads,
add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
self.self_attn = SelfAttention(self.embed_dim,
self.num_heads,
add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
self.ffn = FFN(self.embed_dim, self.ff_dim)
self.read_from_query = CrossAttention(self.embed_dim,
self.num_heads,
add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
norm=this_cfg.read_from_query.output_norm)
self.pixel_ffn = PixelFFN(self.embed_dim)
def forward(
self,
x: torch.Tensor,
pixel: torch.Tensor,
query_pe: torch.Tensor,
pixel_pe: torch.Tensor,
attn_mask: torch.Tensor,
need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
# x: (bs*num_objects)*num_queries*embed_dim
# pixel: bs*num_objects*C*H*W
# query_pe: (bs*num_objects)*num_queries*embed_dim
# pixel_pe: (bs*num_objects)*(H*W)*C
# attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
# bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
x, q_weights = self.read_from_pixel(x,
pixel_flat,
query_pe,
pixel_pe,
attn_mask=attn_mask,
need_weights=need_weights)
x = self.self_attn(x, query_pe)
x = self.ffn(x)
pixel_flat, p_weights = self.read_from_query(pixel_flat,
x,
pixel_pe,
query_pe,
need_weights=need_weights)
pixel = self.pixel_ffn(pixel, pixel_flat)
if need_weights:
bs, num_objects, _, h, w = pixel.shape
q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
self.num_queries, h, w)
return x, pixel, q_weights, p_weights
class QueryTransformer(nn.Module):
def __init__(self, model_cfg: DictConfig):
super().__init__()
this_cfg = model_cfg.object_transformer
self.value_dim = model_cfg.value_dim
self.embed_dim = this_cfg.embed_dim
self.num_heads = this_cfg.num_heads
self.num_queries = this_cfg.num_queries
# query initialization and embedding
self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
# projection from object summaries to query initialization and embedding
self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
self.pixel_pe_scale = model_cfg.pixel_pe_scale
self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
self.spatial_pe = PositionalEncoding(self.embed_dim,
scale=self.pixel_pe_scale,
temperature=self.pixel_pe_temperature,
channel_last=False,
transpose_output=True)
# transformer blocks
self.num_blocks = this_cfg.num_blocks
self.blocks = nn.ModuleList(
QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
self.mask_pred = nn.ModuleList(
nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
for _ in range(self.num_blocks + 1))
self.act = nn.ReLU(inplace=True)
def forward(self,
pixel: torch.Tensor,
obj_summaries: torch.Tensor,
selector: Optional[torch.Tensor] = None,
need_weights: bool = False,
seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
# pixel: B*num_objects*embed_dim*H*W
# obj_summaries: B*num_objects*T*num_queries*embed_dim
T = obj_summaries.shape[2]
bs, num_objects, _, H, W = pixel.shape
# normalize object values
# the last channel is the cumulative area of the object
obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
self.embed_dim + 1)
# sum over time
# during inference, T=1 as we already did streaming average in memory_manager
obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
obj_values = obj_sums / (obj_area + 1e-4)
obj_init = self.summary_to_query_init(obj_values)
obj_emb = self.summary_to_query_emb(obj_values)
# positional embeddings for object queries
query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
# positional embeddings for pixel features
pixel_init = self.pixel_init_proj(pixel)
pixel_emb = self.pixel_emb_proj(pixel)
pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
pixel = pixel_init
# run the transformer
aux_features = {'logits': []}
# first aux output
aux_logits = self.mask_pred[0](pixel).squeeze(2)
attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
aux_features['logits'].append(aux_logits)
for i in range(self.num_blocks):
query, pixel, q_weights, p_weights = self.blocks[i](query,
pixel,
query_emb,
pixel_pe,
attn_mask,
need_weights=need_weights)
if self.training or i <= self.num_blocks - 1 or need_weights:
aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
aux_features['logits'].append(aux_logits)
aux_features['q_weights'] = q_weights # last layer only
aux_features['p_weights'] = p_weights # last layer only
if self.training:
# no need to save all heads
aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
self.num_queries, H, W)[:, :, 0]
return pixel, aux_features
def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
# logits: batch_size*num_objects*H*W
# selector: batch_size*num_objects*1*1
# returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
# where True means the attention is blocked
if selector is None:
prob = logits.sigmoid()
else:
prob = logits.sigmoid() * selector
logits = aggregate(prob, dim=1)
is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
foreground_mask = is_foreground.bool().flatten(start_dim=2)
inv_foreground_mask = ~foreground_mask
inv_background_mask = foreground_mask
aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
return aux_mask

View File

@ -0,0 +1,108 @@
# Reference:
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
import math
import numpy as np
import torch
from torch import nn
def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
"""
Gets a base embedding for one dimension with sin and cos intertwined
"""
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
return torch.flatten(emb, -2, -1)
class PositionalEncoding(nn.Module):
def __init__(self,
dim: int,
scale: float = math.pi * 2,
temperature: float = 10000,
normalize: bool = True,
channel_last: bool = True,
transpose_output: bool = False):
super().__init__()
dim = int(np.ceil(dim / 4) * 2)
self.dim = dim
inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.normalize = normalize
self.scale = scale
self.eps = 1e-6
self.channel_last = channel_last
self.transpose_output = transpose_output
self.cached_penc = None # the cache is irrespective of the number of objects
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: A 4/5d tensor of size
channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
:return: positional encoding tensor that has the same shape as the input if the input is 4d
if the input is 5d, the output is broadcastable along the k-dimension
"""
if len(tensor.shape) != 4 and len(tensor.shape) != 5:
raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
if len(tensor.shape) == 5:
# take a sample from the k dimension
num_objects = tensor.shape[1]
tensor = tensor[:, 0]
else:
num_objects = None
if self.channel_last:
batch_size, h, w, c = tensor.shape
else:
batch_size, c, h, w = tensor.shape
if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
if num_objects is None:
return self.cached_penc
else:
return self.cached_penc.unsqueeze(1)
self.cached_penc = None
pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
if self.normalize:
pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
emb_y = get_emb(sin_inp_y).unsqueeze(1)
emb_x = get_emb(sin_inp_x)
emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
emb[:, :, :self.dim] = emb_x
emb[:, :, self.dim:] = emb_y
if not self.channel_last and self.transpose_output:
# cancelled out
pass
elif (not self.channel_last) or (self.transpose_output):
emb = emb.permute(2, 0, 1)
self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
if num_objects is None:
return self.cached_penc
else:
return self.cached_penc.unsqueeze(1)
if __name__ == '__main__':
pe = PositionalEncoding(8).cuda()
input = torch.ones((1, 8, 8, 8)).cuda()
output = pe(input)
# print(output)
print(output[0, :, 0, 0])
print(output[0, :, 0, 5])
print(output[0, 0, :, 0])
print(output[0, 0, 0, :])

View File

@ -0,0 +1,161 @@
# Modified from PyTorch nn.Transformer
from typing import List, Callable, Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from ...model.channel_attn import CAResBlock
class SelfAttention(nn.Module):
def __init__(self,
dim: int,
nhead: int,
dropout: float = 0.0,
batch_first: bool = True,
add_pe_to_qkv: List[bool] = [True, True, False]):
super().__init__()
self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.add_pe_to_qkv = add_pe_to_qkv
def forward(self,
x: torch.Tensor,
pe: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.norm(x)
if any(self.add_pe_to_qkv):
x_with_pe = x + pe
q = x_with_pe if self.add_pe_to_qkv[0] else x
k = x_with_pe if self.add_pe_to_qkv[1] else x
v = x_with_pe if self.add_pe_to_qkv[2] else x
else:
q = k = v = x
r = x
x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
return r + self.dropout(x)
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
class CrossAttention(nn.Module):
def __init__(self,
dim: int,
nhead: int,
dropout: float = 0.0,
batch_first: bool = True,
add_pe_to_qkv: List[bool] = [True, True, False],
residual: bool = True,
norm: bool = True):
super().__init__()
self.cross_attn = nn.MultiheadAttention(dim,
nhead,
dropout=dropout,
batch_first=batch_first)
if norm:
self.norm = nn.LayerNorm(dim)
else:
self.norm = nn.Identity()
self.dropout = nn.Dropout(dropout)
self.add_pe_to_qkv = add_pe_to_qkv
self.residual = residual
def forward(self,
x: torch.Tensor,
mem: torch.Tensor,
x_pe: torch.Tensor,
mem_pe: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
*,
need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
x = self.norm(x)
if self.add_pe_to_qkv[0]:
q = x + x_pe
else:
q = x
if any(self.add_pe_to_qkv[1:]):
mem_with_pe = mem + mem_pe
k = mem_with_pe if self.add_pe_to_qkv[1] else mem
v = mem_with_pe if self.add_pe_to_qkv[2] else mem
else:
k = v = mem
r = x
x, weights = self.cross_attn(q,
k,
v,
attn_mask=attn_mask,
need_weights=need_weights,
average_attn_weights=False)
if self.residual:
return r + self.dropout(x), weights
else:
return self.dropout(x), weights
class FFN(nn.Module):
def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
super().__init__()
self.linear1 = nn.Linear(dim_in, dim_ff)
self.linear2 = nn.Linear(dim_ff, dim_in)
self.norm = nn.LayerNorm(dim_in)
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
r = x
x = self.norm(x)
x = self.linear2(self.activation(self.linear1(x)))
x = r + x
return x
class PixelFFN(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
self.conv = CAResBlock(dim, dim)
def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
# pixel: batch_size * num_objects * dim * H * W
# pixel_flat: (batch_size*num_objects) * (H*W) * dim
bs, num_objects, _, h, w = pixel.shape
pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
x = self.conv(pixel_flat)
x = x.view(bs, num_objects, self.dim, h, w)
return x
class OutputFFN(nn.Module):
def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
super().__init__()
self.linear1 = nn.Linear(dim_in, dim_out)
self.linear2 = nn.Linear(dim_out, dim_out)
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear2(self.activation(self.linear1(x)))
return x
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

View File

@ -0,0 +1,107 @@
import math
import torch
from typing import Optional, Union, Tuple
# @torch.jit.script
def get_similarity(mk: torch.Tensor,
ms: torch.Tensor,
qk: torch.Tensor,
qe: torch.Tensor,
add_batch_dim: bool = False,
uncert_mask = None) -> torch.Tensor:
# used for training/inference and memory reading/memory potentiation
# mk: B x CK x [N] - Memory keys
# ms: B x 1 x [N] - Memory shrinkage
# qk: B x CK x [HW/P] - Query keys
# qe: B x CK x [HW/P] - Query selection
# Dimensions in [] are flattened
# Return: B*N*HW
if add_batch_dim:
mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
CK = mk.shape[1]
mk = mk.flatten(start_dim=2)
ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
qk = qk.flatten(start_dim=2)
qe = qe.flatten(start_dim=2) if qe is not None else None
# query token selection based on temporal sparsity
if uncert_mask is not None:
uncert_mask = uncert_mask.flatten(start_dim=2)
uncert_mask = uncert_mask.expand(-1, 64, -1)
qk = qk * uncert_mask
qe = qe * uncert_mask
if qe is not None:
# See XMem's appendix for derivation
mk = mk.transpose(1, 2)
a_sq = (mk.pow(2) @ qe)
two_ab = 2 * (mk @ (qk * qe))
b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
similarity = (-a_sq + two_ab - b_sq)
else:
# similar to STCN if we don't have the selection term
a_sq = mk.pow(2).sum(1).unsqueeze(2)
two_ab = 2 * (mk.transpose(1, 2) @ qk)
similarity = (-a_sq + two_ab)
if ms is not None:
similarity = similarity * ms / math.sqrt(CK) # B*N*HW
else:
similarity = similarity / math.sqrt(CK) # B*N*HW
return similarity
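# --- Derivation note (not part of the original file) ---
# With the query-selection term e (qe), the selection-weighted negative squared
# distance between a memory key m and a query key q expands as
#   -sum_c e_c (m_c - q_c)^2 = -(m^2 . e) + 2 m . (q * e) - (e . q^2),
# which is exactly -a_sq + two_ab - b_sq above. Without e, the q^2 term does not
# depend on the memory position and is simply dropped (the STCN-style branch).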
def do_softmax(
similarity: torch.Tensor,
top_k: Optional[int] = None,
inplace: bool = False,
return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# normalize similarity with top-k softmax
# similarity: B x N x [HW/P]
# use inplace with care
if top_k is not None:
values, indices = torch.topk(similarity, k=top_k, dim=1)
x_exp = values.exp_()
x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
if inplace:
similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
affinity = similarity
else:
affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
else:
maxes = torch.max(similarity, dim=1, keepdim=True)[0]
x_exp = torch.exp(similarity - maxes)
x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
affinity = x_exp / x_exp_sum
indices = None
if return_usage:
return affinity, affinity.sum(dim=2)
return affinity
def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
qe: torch.Tensor, uncert_mask = None) -> torch.Tensor:
# shorthand used in training with no top-k
similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask)
affinity = do_softmax(similarity)
return affinity
def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor:
B, CV, T, H, W = mv.shape
mo = mv.view(B, CV, T * H * W)
mem = torch.bmm(mo, affinity)
if uncert_mask is not None:
uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1)
mem = mem * uncert_mask
mem = mem.view(B, CV, H, W)
return mem
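# --- Illustrative usage sketch (not part of the original file) ---
# Shape check of the read-out path with small dummy tensors.
if __name__ == '__main__':
    B, CK, CV, T, H, W = 1, 64, 128, 2, 16, 16
    mk = torch.rand(B, CK, T, H, W)          # memory keys
    ms = torch.rand(B, 1, T, H, W)           # memory shrinkage
    qk = torch.rand(B, CK, H, W)             # query keys
    qe = torch.rand(B, CK, H, W)             # query selection
    mv = torch.rand(B, CV, T, H, W)          # memory values
    affinity = get_affinity(mk, ms, qk, qe)  # B x (T*H*W) x (H*W)
    mem = readout(affinity, mv)              # B x CV x H x W
    print(affinity.shape, mem.shape)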

View File

@ -0,0 +1,72 @@
import logging
log = logging.getLogger()
def get_parameter_groups(model, stage_cfg, print_log=False):
"""
Assign different weight decays and learning rates to different parameters.
Returns parameter groups that can be passed to the optimizer.
"""
weight_decay = stage_cfg.weight_decay
embed_weight_decay = stage_cfg.embed_weight_decay
backbone_lr_ratio = stage_cfg.backbone_lr_ratio
base_lr = stage_cfg.learning_rate
backbone_params = []
embed_params = []
other_params = []
embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
embedding_names = [e + '.weight' for e in embedding_names]
# inspired by detectron2
memo = set()
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# Avoid duplicating parameters
if param in memo:
continue
memo.add(param)
if name.startswith('module'):
name = name[7:]
inserted = False
if name.startswith('pixel_encoder.'):
backbone_params.append(param)
inserted = True
if print_log:
log.info(f'{name} counted as a backbone parameter.')
else:
for e in embedding_names:
if name.endswith(e):
embed_params.append(param)
inserted = True
if print_log:
log.info(f'{name} counted as an embedding parameter.')
break
if not inserted:
other_params.append(param)
parameter_groups = [
{
'params': backbone_params,
'lr': base_lr * backbone_lr_ratio,
'weight_decay': weight_decay
},
{
'params': embed_params,
'lr': base_lr,
'weight_decay': embed_weight_decay
},
{
'params': other_params,
'lr': base_lr,
'weight_decay': weight_decay
},
]
return parameter_groups
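# --- Illustrative usage sketch (not part of the original file) ---
# Hypothetical stage_cfg values; for a toy model every parameter falls into the
# third ("other") group, so empty groups are filtered out before building AdamW.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace
    stage_cfg = SimpleNamespace(weight_decay=0.05, embed_weight_decay=0.0,
                                backbone_lr_ratio=0.1, learning_rate=1e-4)
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    groups = get_parameter_groups(model, stage_cfg, print_log=True)
    optimizer = torch.optim.AdamW([g for g in groups if g['params']])
    print([len(g['params']) for g in groups])   # [0, 0, 4]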

View File

@ -0,0 +1,179 @@
"""
resnet.py - A modified ResNet structure
We append extra channels to the first conv by some network surgery
"""
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.utils import model_zoo
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
new_dict = OrderedDict()
for k1, v1 in target.state_dict().items():
if 'num_batches_tracked' not in k1:
if k1 in source_state:
tar_v = source_state[k1]
if v1.shape != tar_v.shape:
# Init the new segmentation channel with zeros
# print(v1.shape, tar_v.shape)
c, _, w, h = v1.shape
pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
nn.init.orthogonal_(pads)
tar_v = torch.cat([tar_v, pads], 1)
new_dict[k1] = tar_v
target.load_state_dict(new_dict)
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=stride,
dilation=dilation,
padding=dilation,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation))
return nn.Sequential(*layers)
def resnet18(pretrained=True, extra_dim=0):
model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
if pretrained:
load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
return model
def resnet50(pretrained=True, extra_dim=0):
model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
if pretrained:
load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
return model
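# --- Illustrative usage sketch (not part of the original file) ---
# Builds a ResNet-50 whose first conv accepts one extra (e.g. mask) channel;
# with pretrained=True the torchvision weights would be downloaded and the
# extra input channel initialised by the surgery in load_weights_add_extra_dim.
if __name__ == '__main__':
    net = resnet50(pretrained=False, extra_dim=1)
    x = torch.rand(1, 4, 224, 224)    # RGB + 1 extra channel
    f4 = net.maxpool(net.relu(net.bn1(net.conv1(x))))
    f8 = net.layer2(net.layer1(f4))
    print(f4.shape, f8.shape)         # (1, 64, 56, 56) and (1, 512, 28, 28)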

View File

@ -0,0 +1,73 @@
import tqdm
import torch
from torchvision.transforms.functional import to_tensor
import numpy as np
import random
import cv2
def gen_dilate(alpha, min_kernel_size, max_kernel_size):
kernel_size = random.randint(min_kernel_size, max_kernel_size)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
return dilate.astype(np.float32)
def gen_erosion(alpha, min_kernel_size, max_kernel_size):
kernel_size = random.randint(min_kernel_size, max_kernel_size)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
fg = np.array(np.equal(alpha, 255).astype(np.float32))
erode = cv2.erode(fg, kernel, iterations=1)*255
return erode.astype(np.float32)
@torch.inference_mode()
@torch.amp.autocast('cuda')
def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
"""
Args:
frames_np: [(H,W,C)]*n, uint8
mask: (H,W), uint8
Outputs:
com: [(H,W,C)]*n, uint8
pha: [(H,W,C)]*n, uint8
"""
# print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====')
bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3))
objects = [1]
# [optional] erode & dilate on given seg mask
if r_dilate > 0:
mask = gen_dilate(mask, r_dilate, r_dilate)
if r_erode > 0:
mask = gen_erosion(mask, r_erode, r_erode)
mask = torch.from_numpy(mask).cuda()
frames_np = [frames_np[0]]* n_warmup + frames_np
frames = []
phas = []
for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
image = to_tensor(frame_single).cuda().float()
if ti == 0:
output_prob = processor.step(image, mask, objects=objects) # encode given mask
output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
else:
if ti <= n_warmup:
output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
else:
output_prob = processor.step(image)
# convert output probabilities to an object mask
mask = processor.output_prob_to_mask(output_prob)
pha = mask.unsqueeze(2).cpu().numpy()
com_np = frame_single / 255. * pha + bgr * (1 - pha)
# DONOT save the warmup frames
if ti > (n_warmup-1):
frames.append((com_np*255).astype(np.uint8))
phas.append((pha*255).astype(np.uint8))
return frames, phas
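# --- Usage note (not part of the original file) ---
# `processor` is expected to expose step(...) and output_prob_to_mask(...) as
# used above (e.g. an InferenceCore built around the MatAnyone network
# elsewhere in this repo); frames_np is a list of HxWx3 uint8 RGB frames and
# mask an HxW uint8 mask of the target object. A hedged call sketch:
#
#   frames, alphas = matanyone(processor, frames_np, mask, r_erode=10, r_dilate=10)
#   # frames: green-screen composites, alphas: per-frame alpha mattes (uint8)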

View File

@ -0,0 +1,141 @@
import time
import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from .mask_painter import mask_painter
class BaseSegmenter:
def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
"""
device: model device
SAM_checkpoint: path of SAM checkpoint
model_type: vit_b, vit_l, vit_h
"""
print(f"Initializing BaseSegmenter to {device}")
assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
self.device = device
# SAM_checkpoint = None
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
from accelerate import init_empty_weights
# self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
with init_empty_weights():
self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
from mmgp import offload
# self.model.to(torch.float16)
# offload.save_model(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors")
offload.load_model_data(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors")
self.model.to(torch.float32) # TODO: optimize; precision degrades noticeably below fp32
self.model.to(device=self.device)
self.predictor = SamPredictor(self.model)
self.embedded = False
@torch.no_grad()
def set_image(self, image: np.ndarray):
# image: 3-channel RGB numpy array (e.g. loaded via PIL and converted)
# cache the image embedding to avoid encoding the same image multiple times
self.orignal_image = image
if self.embedded:
print('image already embedded; call reset_image() before embedding a new one.')
return
self.predictor.set_image(image)
self.embedded = True
return
@torch.no_grad()
def reset_image(self):
# reset image embeding
self.predictor.reset_image()
self.embedded = False
def predict(self, prompts, mode, multimask=True):
"""
image: numpy array, h, w, 3
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
prompts['point_coords']: numpy array [N,2]
prompts['point_labels']: numpy array [1,N]
prompts['mask_input']: numpy array [1,256,256]
mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
multimask: True (return 3 masks), False (return 1 mask only)
when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
"""
assert self.embedded, 'prediction is called before set_image (feature embedding).'
assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
with torch.autocast(device_type='cuda', dtype=torch.float16):
if mode == 'point':
masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
point_labels=prompts['point_labels'],
multimask_output=multimask)
elif mode == 'mask':
masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
multimask_output=multimask)
elif mode == 'both': # both
masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
point_labels=prompts['point_labels'],
mask_input=prompts['mask_input'],
multimask_output=multimask)
else:
raise("Not implement now!")
# masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks, scores, logits
if __name__ == "__main__":
# load and show an image
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
# initialise BaseSegmenter
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
device = "cuda:4"
base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
# image embedding (once embedded, multiple prompts can be applied)
base_segmenter.set_image(image)
# examples
# point only ------------------------
mode = 'point'
prompts = {
'point_coords': np.array([[500, 375], [1125, 625]]),
'point_labels': np.array([1, 1]),
}
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
# both ------------------------
mode = 'both'
mask_input = logits[np.argmax(scores), :, :]
prompts = {
'point_coords': np.array([[500, 375], [1125, 625]]),
'point_labels': np.array([1, 0]),
'mask_input': mask_input[None, :, :]
}
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
# mask only ------------------------
mode = 'mask'
mask_input = logits[np.argmax(scores), :, :]
prompts = {'mask_input': mask_input[None, :, :]}
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)

View File

@ -0,0 +1,109 @@
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse
def sizeof_fmt(size, suffix='B'):
"""Get human readable file size.
Args:
size (int): File size.
suffix (str): Suffix. Default: 'B'.
Return:
str: Formatted file size.
"""
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(size) < 1024.0:
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'
def download_file_from_google_drive(file_id, save_path):
"""Download files from google drive.
Ref:
https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
Args:
file_id (str): File id.
save_path (str): Save path.
"""
session = requests.Session()
URL = 'https://docs.google.com/uc?export=download'
params = {'id': file_id}
response = session.get(URL, params=params, stream=True)
token = get_confirm_token(response)
if token:
params['confirm'] = token
response = session.get(URL, params=params, stream=True)
# get file size
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
print(response_file_size)
if 'Content-Range' in response_file_size.headers:
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
else:
file_size = None
save_response_content(response, save_path, file_size)
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def save_response_content(response, destination, file_size=None, chunk_size=32768):
if file_size is not None:
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
readable_file_size = sizeof_fmt(file_size)
else:
pbar = None
with open(destination, 'wb') as f:
downloaded_size = 0
for chunk in response.iter_content(chunk_size):
downloaded_size += chunk_size
if pbar is not None:
pbar.update(1)
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if pbar is not None:
pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
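# --- Illustrative usage sketch (not part of the original file) ---
# The URL is just an example; the helper skips the download if the file is
# already cached and returns its absolute path.
if __name__ == '__main__':
    # downloads into the default torch hub checkpoints directory
    path = load_file_from_url('https://download.pytorch.org/models/resnet18-5c106cde.pth')
    print(path)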

View File

@ -0,0 +1,99 @@
import time
import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from .mask_painter import mask_painter as mask_painter2
from .base_segmenter import BaseSegmenter
from .painter import mask_painter, point_painter
import os
import requests
import sys
mask_color = 3
mask_alpha = 0.7
contour_color = 1
contour_width = 5
point_color_ne = 8
point_color_ps = 50
point_alpha = 0.9
point_radius = 15
contour_color = 2
contour_width = 5
class SamControler():
def __init__(self, SAM_checkpoint, model_type, device):
'''
initialize the SAM controller
'''
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
# def seg_again(self, image: np.ndarray):
# '''
# it is used when interact in video
# '''
# self.sam_controler.reset_image()
# self.sam_controler.set_image(image)
# return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
'''
used on the first frame of a video
return: mask, logit, painted image (mask + points)
'''
# self.sam_controler.set_image(image)
origal_image = self.sam_controler.orignal_image
neg_flag = labels[-1]
if neg_flag==1:
#find neg
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
prompts = {
'point_coords': points,
'point_labels': labels,
'mask_input': logit[None, :, :]
}
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
else:
#find positive
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels)
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image

View File

@ -0,0 +1,288 @@
import cv2
import torch
import numpy as np
from PIL import Image
import copy
import time
def colormap(rgb=True):
color_list = np.array(
[
0.000, 0.000, 0.000,
1.000, 1.000, 1.000,
1.000, 0.498, 0.313,
0.392, 0.581, 0.929,
0.000, 0.447, 0.741,
0.850, 0.325, 0.098,
0.929, 0.694, 0.125,
0.494, 0.184, 0.556,
0.466, 0.674, 0.188,
0.301, 0.745, 0.933,
0.635, 0.078, 0.184,
0.300, 0.300, 0.300,
0.600, 0.600, 0.600,
1.000, 0.000, 0.000,
1.000, 0.500, 0.000,
0.749, 0.749, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 1.000,
0.667, 0.000, 1.000,
0.333, 0.333, 0.000,
0.333, 0.667, 0.000,
0.333, 1.000, 0.000,
0.667, 0.333, 0.000,
0.667, 0.667, 0.000,
0.667, 1.000, 0.000,
1.000, 0.333, 0.000,
1.000, 0.667, 0.000,
1.000, 1.000, 0.000,
0.000, 0.333, 0.500,
0.000, 0.667, 0.500,
0.000, 1.000, 0.500,
0.333, 0.000, 0.500,
0.333, 0.333, 0.500,
0.333, 0.667, 0.500,
0.333, 1.000, 0.500,
0.667, 0.000, 0.500,
0.667, 0.333, 0.500,
0.667, 0.667, 0.500,
0.667, 1.000, 0.500,
1.000, 0.000, 0.500,
1.000, 0.333, 0.500,
1.000, 0.667, 0.500,
1.000, 1.000, 0.500,
0.000, 0.333, 1.000,
0.000, 0.667, 1.000,
0.000, 1.000, 1.000,
0.333, 0.000, 1.000,
0.333, 0.333, 1.000,
0.333, 0.667, 1.000,
0.333, 1.000, 1.000,
0.667, 0.000, 1.000,
0.667, 0.333, 1.000,
0.667, 0.667, 1.000,
0.667, 1.000, 1.000,
1.000, 0.000, 1.000,
1.000, 0.333, 1.000,
1.000, 0.667, 1.000,
0.167, 0.000, 0.000,
0.333, 0.000, 0.000,
0.500, 0.000, 0.000,
0.667, 0.000, 0.000,
0.833, 0.000, 0.000,
1.000, 0.000, 0.000,
0.000, 0.167, 0.000,
0.000, 0.333, 0.000,
0.000, 0.500, 0.000,
0.000, 0.667, 0.000,
0.000, 0.833, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 0.167,
0.000, 0.000, 0.333,
0.000, 0.000, 0.500,
0.000, 0.000, 0.667,
0.000, 0.000, 0.833,
0.000, 0.000, 1.000,
0.143, 0.143, 0.143,
0.286, 0.286, 0.286,
0.429, 0.429, 0.429,
0.571, 0.571, 0.571,
0.714, 0.714, 0.714,
0.857, 0.857, 0.857
]
).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
if not rgb:
color_list = color_list[:, ::-1]
return color_list
color_list = colormap()
color_list = color_list.astype('uint8').tolist()
def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
background_color = np.array(background_color)
contour_color = np.array(contour_color)
# background_mask = 1 - background_mask
# contour_mask = 1 - contour_mask
for i in range(3):
image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+ background_color[i] * (background_alpha-background_mask*background_alpha)
image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+ contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
return image.astype('uint8')
def mask_generator_00(mask, background_radius, contour_radius):
# no background width when '00'
# distance map
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
contour_mask[contour_mask>0.5] = 1.
return mask, contour_mask
def mask_generator_01(mask, background_radius, contour_radius):
# no background width when '00'
# distance map
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
return mask, contour_mask
def mask_generator_10(mask, background_radius, contour_radius):
# distance map
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# .....:::::!!!!!
background_mask = np.clip(dist_map, -background_radius, background_radius)
background_mask = (background_mask - np.min(background_mask))
background_mask = background_mask / np.max(background_mask)
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
contour_mask[contour_mask>0.5] = 1.
return background_mask, contour_mask
def mask_generator_11(mask, background_radius, contour_radius):
# distance map
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# .....:::::!!!!!
background_mask = np.clip(dist_map, -background_radius, background_radius)
background_mask = (background_mask - np.min(background_mask))
background_mask = background_mask / np.max(background_mask)
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
return background_mask, contour_mask
def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
"""
Input:
input_image: numpy array
input_mask: numpy array
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
background_blur_radius: radius of background blur, must be odd number
contour_width: width of mask contour, must be odd number
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
Output:
painted_image: numpy array
"""
assert input_image.shape[:2] == input_mask.shape, 'different shape'
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
# downsample input image and mask
width, height = input_image.shape[0], input_image.shape[1]
res = 1024
ratio = min(1.0 * res / max(width, height), 1.0)
input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
# 0: background, 1: foreground
msk = np.clip(input_mask, 0, 1)
# generate masks for background and contour pixels
background_radius = (background_blur_radius - 1) // 2
contour_radius = (contour_width - 1) // 2
generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
# paint
painted_image = vis_add_mask\
(input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
return painted_image
if __name__ == '__main__':
background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
background_blur_radius = 31 # radius of background blur, must be odd number
contour_width = 11 # contour width, must be odd number
contour_color = 3 # id in color map, 0: black, 1: white, >1: others
contour_alpha = 1 # transparency of background, 0: no contour highlighted
# load input image and mask
input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
# paint
overall_time_1 = 0
overall_time_2 = 0
overall_time_3 = 0
overall_time_4 = 0
overall_time_5 = 0
for i in range(50):
t2 = time.time()
painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
e2 = time.time()
t3 = time.time()
painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
e3 = time.time()
t1 = time.time()
painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
e1 = time.time()
t4 = time.time()
painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
e4 = time.time()
t5 = time.time()
painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
e5 = time.time()
overall_time_1 += (e1 - t1)
overall_time_2 += (e2 - t2)
overall_time_3 += (e3 - t3)
overall_time_4 += (e4 - t4)
overall_time_5 += (e5 - t5)
print(f'average time w gaussian: {overall_time_1/50}')
print(f'average time w/o gaussian00: {overall_time_2/50}')
print(f'average time w/o gaussian10: {overall_time_3/50}')
print(f'average time w/o gaussian01: {overall_time_4/50}')
print(f'average time w/o gaussian11: {overall_time_5/50}')
# save
painted_image_00 = Image.fromarray(painted_image_00)
painted_image_00.save('./test_img/painter_output_image_00.png')
painted_image_10 = Image.fromarray(painted_image_10)
painted_image_10.save('./test_img/painter_output_image_10.png')
painted_image_01 = Image.fromarray(painted_image_01)
painted_image_01.save('./test_img/painter_output_image_01.png')
painted_image_11 = Image.fromarray(painted_image_11)
painted_image_11.save('./test_img/painter_output_image_11.png')

View File

@ -0,0 +1,131 @@
import os
import re
import random
import time
import torch
import torch.nn as nn
import logging
import numpy as np
from os import path as osp
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
initialized_logger = {}
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added.
Args:
logger_name (str): root logger name. Default: 'basicsr'.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = logging.getLogger(logger_name)
# if the logger has been initialized, just return it
if logger_name in initialized_logger:
return logger
format_str = '%(asctime)s %(levelname)s: %(message)s'
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(format_str))
logger.addHandler(stream_handler)
logger.propagate = False
if log_file is not None:
logger.setLevel(log_level)
# add file handler
# file_handler = logging.FileHandler(log_file, 'w')
file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
file_handler.setFormatter(logging.Formatter(format_str))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)
initialized_logger[logger_name] = True
return logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 12, 0]
def gpu_is_available():
if IS_HIGH_VERSION:
if torch.backends.mps.is_available():
return True
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
def get_device(gpu_id=None):
if gpu_id is None:
gpu_str = ''
elif isinstance(gpu_id, int):
gpu_str = f':{gpu_id}'
else:
raise TypeError('Input should be int value.')
if IS_HIGH_VERSION:
if torch.backends.mps.is_available():
return torch.device('mps'+gpu_str)
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
def set_random_seed(seed):
"""Set random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the files of interest, with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
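# --- Illustrative usage sketch (not part of the original file) ---
if __name__ == '__main__':
    # list every .py file under the current directory as relative paths
    for p in scandir('.', suffix='.py', recursive=True):
        print(p)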

View File

@ -0,0 +1,215 @@
# paint masks, contours, or points on images, with specified colors
import cv2
import torch
import numpy as np
from PIL import Image
import copy
import time
def colormap(rgb=True):
color_list = np.array(
[
0.000, 0.000, 0.000,
1.000, 1.000, 1.000,
1.000, 0.498, 0.313,
0.392, 0.581, 0.929,
0.000, 0.447, 0.741,
0.850, 0.325, 0.098,
0.929, 0.694, 0.125,
0.494, 0.184, 0.556,
0.466, 0.674, 0.188,
0.301, 0.745, 0.933,
0.635, 0.078, 0.184,
0.300, 0.300, 0.300,
0.600, 0.600, 0.600,
1.000, 0.000, 0.000,
1.000, 0.500, 0.000,
0.749, 0.749, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 1.000,
0.667, 0.000, 1.000,
0.333, 0.333, 0.000,
0.333, 0.667, 0.000,
0.333, 1.000, 0.000,
0.667, 0.333, 0.000,
0.667, 0.667, 0.000,
0.667, 1.000, 0.000,
1.000, 0.333, 0.000,
1.000, 0.667, 0.000,
1.000, 1.000, 0.000,
0.000, 0.333, 0.500,
0.000, 0.667, 0.500,
0.000, 1.000, 0.500,
0.333, 0.000, 0.500,
0.333, 0.333, 0.500,
0.333, 0.667, 0.500,
0.333, 1.000, 0.500,
0.667, 0.000, 0.500,
0.667, 0.333, 0.500,
0.667, 0.667, 0.500,
0.667, 1.000, 0.500,
1.000, 0.000, 0.500,
1.000, 0.333, 0.500,
1.000, 0.667, 0.500,
1.000, 1.000, 0.500,
0.000, 0.333, 1.000,
0.000, 0.667, 1.000,
0.000, 1.000, 1.000,
0.333, 0.000, 1.000,
0.333, 0.333, 1.000,
0.333, 0.667, 1.000,
0.333, 1.000, 1.000,
0.667, 0.000, 1.000,
0.667, 0.333, 1.000,
0.667, 0.667, 1.000,
0.667, 1.000, 1.000,
1.000, 0.000, 1.000,
1.000, 0.333, 1.000,
1.000, 0.667, 1.000,
0.167, 0.000, 0.000,
0.333, 0.000, 0.000,
0.500, 0.000, 0.000,
0.667, 0.000, 0.000,
0.833, 0.000, 0.000,
1.000, 0.000, 0.000,
0.000, 0.167, 0.000,
0.000, 0.333, 0.000,
0.000, 0.500, 0.000,
0.000, 0.667, 0.000,
0.000, 0.833, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 0.167,
0.000, 0.000, 0.333,
0.000, 0.000, 0.500,
0.000, 0.000, 0.667,
0.000, 0.000, 0.833,
0.000, 0.000, 1.000,
0.143, 0.143, 0.143,
0.286, 0.286, 0.286,
0.429, 0.429, 0.429,
0.571, 0.571, 0.571,
0.714, 0.714, 0.714,
0.857, 0.857, 0.857
]
).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
if not rgb:
color_list = color_list[:, ::-1]
return color_list
color_list = colormap()
color_list = color_list.astype('uint8').tolist()
def vis_add_mask(image, mask, color, alpha):
color = np.array(color_list[color])
mask = mask > 0.5
image[mask] = image[mask] * (1-alpha) + color * alpha
return image.astype('uint8')
def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
h, w = input_image.shape[:2]
point_mask = np.zeros((h, w)).astype('uint8')
for point in input_points:
point_mask[point[1], point[0]] = 1
kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
point_mask = cv2.dilate(point_mask, kernel)
contour_radius = (contour_width - 1) // 2
dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
contour_mask[contour_mask>0.5] = 1.
# paint mask
painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
# paint contour
painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
return painted_image
def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
# 0: background, 1: foreground
mask = np.clip(input_mask, 0, 1)
contour_radius = (contour_width - 1) // 2
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
dist_map = dist_transform_fore - dist_transform_back
# ...:::!!!:::...
contour_radius += 2
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
contour_mask = contour_mask / np.max(contour_mask)
contour_mask[contour_mask>0.5] = 1.
# paint mask
painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
# paint contour
painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
return painted_image
def background_remover(input_image, input_mask):
"""
input_image: H, W, 3, np.array
input_mask: H, W, np.array
image_wo_background: PIL.Image
"""
assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
# 0: background, 1: foreground
mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
return image_wo_background
if __name__ == '__main__':
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
# example of mask painter
mask_color = 3
mask_alpha = 0.7
contour_color = 1
contour_width = 5
# save
painted_image = Image.fromarray(input_image)
painted_image.save('images/original.png')
painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save('images/original1.png')
# example of point painter
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
input_points = np.array([[500, 375], [70, 600]]) # x, y
point_color = 5
point_alpha = 0.9
point_radius = 15
contour_color = 2
contour_width = 5
painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
# save
painted_image = Image.fromarray(painted_image_1)
painted_image.save('images/point_painter_1.png')
input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
# save
painted_image = Image.fromarray(painted_image_2)
painted_image.save('images/point_painter_2.png')
# example of background remover
input_image = np.array(Image.open('images/original.png').convert('RGB'))
image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
image_wo_background.save('images/image_wo_background.png')


@ -0,0 +1,27 @@
"""
A helper function to get a default model for quick testing
"""
from omegaconf import open_dict
from hydra import compose, initialize
import torch
from ..matanyone.model.matanyone import MatAnyone
def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
cfg = compose(config_name="eval_matanyone_config")
with open_dict(cfg):
cfg['weights'] = ckpt_path
# Load the network weights
if device is not None:
matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
model_weights = torch.load(cfg.weights, map_location=device)
else: # if device is not specified, `.cuda()` by default
matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
model_weights = torch.load(cfg.weights)
matanyone.load_weights(model_weights)
return matanyone
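# Usage sketch (editor's illustrative note, not part of the original file; the
# checkpoint path below is an assumption -- point it at the MatAnyone weights
# you actually downloaded):
#
#   matanyone_model = get_matanyone_model("ckpts/matanyone.pth", device="cuda")
#   # the model is returned in eval() mode and can then be handed to
#   # InferenceCore for frame-by-frame matting.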


@ -0,0 +1,62 @@
from typing import List, Iterable
import torch
import torch.nn.functional as F
# STM
def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
h, w = in_img.shape[-2:]
if h % d > 0:
new_h = h + d - h % d
else:
new_h = h
if w % d > 0:
new_w = w + d - w % d
else:
new_w = w
lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
pad_array = (int(lw), int(uw), int(lh), int(uh))
out = F.pad(in_img, pad_array)
return out, pad_array
def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
if len(img.shape) == 4:
if pad[2] + pad[3] > 0:
img = img[:, :, pad[2]:-pad[3], :]
if pad[0] + pad[1] > 0:
img = img[:, :, :, pad[0]:-pad[1]]
elif len(img.shape) == 3:
if pad[2] + pad[3] > 0:
img = img[:, pad[2]:-pad[3], :]
if pad[0] + pad[1] > 0:
img = img[:, :, pad[0]:-pad[1]]
elif len(img.shape) == 5:
if pad[2] + pad[3] > 0:
img = img[:, :, :, pad[2]:-pad[3], :]
if pad[0] + pad[1] > 0:
img = img[:, :, :, :, pad[0]:-pad[1]]
else:
raise NotImplementedError
return img
# @torch.jit.script
def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
with torch.amp.autocast("cuda"):
prob = prob.float()
new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
dim).clamp(1e-7, 1 - 1e-7)
logits = torch.log((new_prob / (1 - new_prob))) # (0, 1) --> (-inf, inf)
return logits
# @torch.jit.script
def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
# cls_gt: B*1*H*W
B, _, H, W = cls_gt.shape
one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
return one_hot
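# Quick sanity-check sketch (editor's illustrative note, not part of the
# original file; shapes are arbitrary):
#
#   x = torch.randn(1, 3, 250, 250)
#   padded, pads = pad_divide_by(x, 16)   # pads symmetrically up to 256x256
#   restored = unpad(padded, pads)        # crops back to the original 250x250
#   assert restored.shape == x.shape
#
#   probs = torch.rand(1, 2, 4, 4)        # per-object probabilities
#   logits = aggregate(probs, dim=1)      # prepends a background channel, returns logits
#   assert logits.shape == (1, 3, 4, 4)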


@ -24,4 +24,8 @@ onnxruntime-gpu
rembg[gpu]==2.0.65
matplotlib
timm
segment-anything
ffmpeg-python
omegaconf
hydra-core
# rembg==2.0.65


@ -482,7 +482,6 @@ class WanAttentionBlock(nn.Module):
y *= 1 + e[4]
y += e[3]
ffn = self.ffn[0]
gelu = self.ffn[1]
ffn2= self.ffn[2]
@ -500,8 +499,6 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[5])
if hint is not None:
if context_scale == 1:
x.add_(hint)
@ -539,24 +536,13 @@ class VaceWanAttentionBlock(WanAttentionBlock):
c = hints[0]
hints[0] = None
if self.block_id == 0:
c = self.before_proj(c) + x
c = self.before_proj(c)
c += x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
hints[0] = c
return c_skip
# def forward(self, c, x, **kwargs):
# # behold dbm magic !
# if self.block_id == 0:
# c = self.before_proj(c) + x
# all_c = []
# else:
# all_c = c
# c = all_c.pop(-1)
# c = super().forward(c, **kwargs)
# c_skip = self.after_proj(c)
# all_c += [c_skip, c]
# return all_c
class Head(nn.Module):
@ -794,37 +780,6 @@ class WanModel(ModelMixin, ConfigMixin):
return best_threshold
# def forward_vace(
# self,
# x,
# vace_context,
# seq_len,
# context,
# e,
# kwargs
# ):
# # embeddings
# c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
# c = [u.flatten(2).transpose(1, 2) for u in c]
# if (len(c) == 1 and seq_len == c[0].size(1)):
# c = c[0]
# else:
# c = torch.cat([
# torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
# dim=1) for u in c
# ])
# # arguments
# new_kwargs = dict(x=x)
# new_kwargs.update(kwargs)
# for block in self.vace_blocks:
# c = block(c, context= context, e= e, **new_kwargs)
# hints = c[:-1]
# return hints
def forward(
self,
x,


@ -209,34 +209,52 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
trim_video = len(keep_frames)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_frames):
if not keep:
src_video[i][:, k:k+1] = 0
src_mask[i][:, k:k+1] = 1
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:


@ -22,27 +22,27 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def resample(video_fps, video_frames_count, max_frames, target_fps):
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
video_frame_duration = 1 /video_fps
target_frame_duration = 1 / target_fps
cur_time = 0
target_time = 0
frame_no = 0
target_time = start_target_frame * target_frame_duration
frame_no = math.ceil(target_time / video_frame_duration)
cur_time = frame_no * video_frame_duration
frame_ids =[]
while True:
if max_frames != 0 and len(frame_ids) >= max_frames:
if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count :
break
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
frame_no += add_frames_count
if frame_no >= video_frames_count:
break
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
if frame_no >= video_frames_count -1:
break
frame_ids = frame_ids[:video_frames_count]
frame_ids = frame_ids[:max_target_frames_count]
return frame_ids
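# Worked example (editor's illustrative note, not part of the original file):
# mapping a 16 fps target onto a 30 fps source picks roughly every 1.875th
# source frame, e.g.
#
#   resample(video_fps=30, video_frames_count=90,
#            max_target_frames_count=5, target_fps=16, start_target_frame=0)
#   # -> [0, 2, 4, 6, 8]
#
# A non-zero start_target_frame shifts the starting point so a sliding window
# can resume where the previous one stopped.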
def get_video_frame(file_name, frame_no):


@ -182,14 +182,14 @@ class VaceVideoProcessor(object):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0):
from wan.utils.utils import resample
target_fps = self.max_fps
# video_frames_count = len(frame_timestamps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame )
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
@ -206,7 +206,7 @@ class VaceVideoProcessor(object):
np.log2(np.sqrt(max_area_z))
)))
seq_len = max_area_z * ((max_frames- 1) // df +1)
seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1)
# of = min(
# (len(frame_ids) - 1) // df + 1,
@ -226,9 +226,9 @@ class VaceVideoProcessor(object):
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
else:
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
@ -238,7 +238,7 @@ class VaceVideoProcessor(object):
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, **kwargs):
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
@ -254,7 +254,7 @@ class VaceVideoProcessor(object):
if src_video != None:
fps = 16
length = src_video.shape[1]
length = src_video.shape[0] + start_frame
if len(readers) > 0:
min_readers = min([len(r) for r in readers])
length = min(length, min_readers )
@ -269,7 +269,7 @@ class VaceVideoProcessor(object):
h, w = src_video.shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame )
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]

wgp.py

@ -144,16 +144,26 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return
sliding_window_repeat = inputs["sliding_window_repeat"]
sliding_window = sliding_window_repeat > 0
if "Vace" in model_filename:
video_prompt_type = inputs["video_prompt_type"]
image_refs = inputs["image_refs"]
video_guide = inputs["video_guide"]
video_mask = inputs["video_mask"]
if "Vace" in model_filename and "1.3B" in model_filename :
if sliding_window:
if inputs["repeat_generation"]!=1:
gr.Info("Only one Video generated per Prompt is supported when Sliding windows is used")
return
if inputs["sliding_window_overlap"]>=inputs["video_length"] :
gr.Info("The number of frames of the Sliding Window Overlap must be less than the Number of Frames to Generate")
return
if "1.3B" in model_filename :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = VACE_SIZE_CONFIGS.keys().join(" and ")
res = (" and ").join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
return
if "I" in video_prompt_type:
@ -175,12 +185,19 @@ def process_prompt_and_add_tasks(state, model_choice):
else:
video_mask = None
if "O" in video_prompt_type :
max_frames= inputs["max_frames"]
keep_frames= inputs["keep_frames"]
video_length = inputs["video_length"]
if max_frames ==0:
if len(keep_frames) ==0:
gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.")
elif max_frames >= video_length:
gr.Info(f"The number of frames in the control Video to reuse ({max_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
# elif keep_frames >= video_length:
# gr.Info(f"The number of frames in the control Video to reuse ({keep_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
# return
elif "V" in video_prompt_type:
keep_frames= inputs["keep_frames"]
video_length = inputs["video_length"]
_, error = parse_keep_frames(keep_frames, video_length)
if len(error) > 0:
gr.Info(f"Invalid Keep Frames property: {error}")
return
if isinstance(image_refs, list):
@ -190,6 +207,9 @@ def process_prompt_and_add_tasks(state, model_choice):
image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1)
if sliding_window and len(prompts) > 0:
prompts = ["\n".join(prompts)]
for single_prompt in prompts:
extra_inputs = {
"prompt" : single_prompt,
@ -1545,8 +1565,8 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0:
@ -1787,25 +1807,6 @@ def get_model_name(model_filename):
return model_name
# def generate_header(model_filename, compile, attention_mode):
# header = "<div class='title-with-lines'><div class=line></div><h2>"
# model_name = get_model_name(model_filename)
# header += model_name
# header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
# if attention_mode not in attention_modes_installed:
# header += " -NOT INSTALLED-"
# elif attention_mode not in attention_modes_supported:
# header += " -NOT SUPPORTED-"
# if compile:
# header += ", pytorch compilation ON"
# header += ") </h2><div class=line></div> "
# return header
def generate_header(model_filename, compile, attention_mode):
@ -2070,7 +2071,7 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))
def preprocess_video(process_type, height, width, video_in, max_frames):
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0):
from wan.utils.utils import resample
@ -2080,8 +2081,10 @@ def preprocess_video(process_type, height, width, video_in, max_frames):
fps = reader.get_avg_fps()
frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
frames_list = reader.get_batch(frame_nos)
if len(frames_list) == 0:
return None
frame_height, frame_width, _ = frames_list[0].shape
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
@ -2127,6 +2130,57 @@ def preprocess_video(process_type, height, width, video_in, max_frames):
return torch.stack(torch_frames)
def parse_keep_frames(keep_frames, video_length):
def is_integer(n):
try:
float(n)
except ValueError:
return False
else:
return float(n).is_integer()
def absolute(n):
if n==0:
return 0
elif n < 0:
return max(0, video_length + n)
else:
return min(n-1, video_length-1)
if len(keep_frames) == 0:
return [True] *video_length, ""
frames =[False] *video_length
error = ""
sections = keep_frames.split(" ")
for section in sections:
section = section.strip()
if ":" in section:
parts = section.split(":")
if not is_integer(parts[0]):
error =f"Invalid integer {parts[0]}"
break
start_range = absolute(int(parts[0]))
if not is_integer(parts[1]):
error =f"Invalid integer {parts[1]}"
break
end_range = absolute(int(parts[1]))
for i in range(start_range, end_range + 1):
frames[i] = True
else:
if not is_integer(section):
error =f"Invalid integer {section}"
break
index = absolute(int(section))
frames[index] = True
if len(error ) > 0:
return [], error
for i in range(len(frames)-1, 0, -1):
if frames[i]:
break
frames= frames[0: i+1]
return frames, error
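# Behaviour sketch (editor's illustrative note, not part of the original file):
#
#   parse_keep_frames("", 81)      # -> ([True] * 81, "")   keep every frame
#   parse_keep_frames("1:9", 81)   # -> ([True] * 9, "")    keep frames 1..9, trailing False entries are trimmed
#   parse_keep_frames("foo", 81)   # -> ([], "Invalid integer foo")   errors are returned, not raised
#
# Negative values count from the end of the video, e.g. "-1" keeps only the last frame.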
def generate_video(
task_id,
progress,
@ -2152,7 +2206,10 @@ def generate_video(
image_refs,
video_guide,
video_mask,
max_frames,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
@ -2308,37 +2365,6 @@ def generate_video(
else:
raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
# video_prompt_type = video_prompt_type +"G"
if any(process in video_prompt_type for process in ("P", "D", "G")) :
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, 1, 1)
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, status + " - Extracting Open Pose Information"]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, status + " - Extracting Depth Information"]
preprocess_type = "depth"
elif "G" in video_prompt_type :
progress_args = [0, status + " - Extracting Gray Level Information"]
preprocess_type = "gray"
if preprocess_type != None :
progress(*progress_args )
gen["progress_args"] = progress_args
video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
image_refs = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
[video_mask],
[image_refs],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
original_video= "O" in video_prompt_type,
trim_video=max_frames)
else:
src_video, src_mask, src_ref_images = None, None, None
import random
if seed == None or seed <0:
@ -2355,6 +2381,21 @@ def generate_video(
gen["prompt"] = prompt
repeat_no = 0
extra_generation = 0
start_frame = 0
sliding_window = sliding_window_repeat > 0
if sliding_window:
reuse_frames = sliding_window_overlap
discard_last_frames = sliding_window_discard_last_frames #4
repeat_generation = sliding_window_repeat
prompts = prompt.split("\n")
prompts = [part for part in prompts if len(part)>0]
gen["sliding_window"] = sliding_window
frames_already_processed = None
pre_video_guide = None
while True:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
@ -2362,10 +2403,59 @@ def generate_video(
gen["total_generation"] = total_generation
if abort or repeat_no >= total_generation:
break
if "Vace" in model_filename and (repeat_no == 0 or sliding_window):
if sliding_window:
prompt = prompts[repeat_no] if repeat_no < len(prompts) else prompts[-1]
# video_prompt_type = video_prompt_type +"G"
image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
video_guide_copy = video_guide
video_mask_copy = video_mask
if any(process in video_prompt_type for process in ("P", "D", "G")) :
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, 1, 1, sliding_window)
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, status + " - Extracting Open Pose Information"]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, status + " - Extracting Depth Information"]
preprocess_type = "depth"
elif "G" in video_prompt_type :
progress_args = [0, status + " - Extracting Gray Level Information"]
preprocess_type = "gray"
if preprocess_type != None :
progress(*progress_args )
gen["progress_args"] = progress_args
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if repeat_no ==0 else video_length - reuse_frames, start_frame = start_frame)
keep_frames_parsed, error = parse_keep_frames(keep_frames, video_length)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames}")
if repeat_no == 0:
image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy ],
[image_refs_copy],
video_length, image_size = image_size, device ="cpu",
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed,
start_frame = start_frame,
pre_src_video = [pre_video_guide]
)
if repeat_no == 0 and src_video != None and len(src_video) > 0:
image_size = src_video[0].shape[-2:]
else:
src_video, src_mask, src_ref_images = None, None, None
repeat_no +=1
gen["repeat_no"] = repeat_no
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
yield status
@ -2501,6 +2591,15 @@ def generate_video(
# yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
if sliding_window :
start_frame += video_length
if discard_last_frames > 0:
sample = sample[: , :-discard_last_frames]
start_frame -= discard_last_frames
pre_video_guide = sample[:, -reuse_frames:]
if repeat_no > 1:
sample = sample[: , reuse_frames:]
start_frame -= reuse_frames
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
@ -2527,7 +2626,13 @@ def generate_video(
if exp > 0:
from rife.inference import temporal_interpolation
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
if sliding_window and repeat_no > 1:
sample = torch.cat([frames_already_processed[:, -2:-1], sample], dim=1)
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
sample = sample[:, 1:]
else:
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
fps = fps * 2**exp
if len(spatial_upsampling) > 0:
@ -2552,6 +2657,12 @@ def generate_video(
new_frames = None
sample = sample * 2 - 1
if sliding_window :
if repeat_no == 1:
frames_already_processed = sample
else:
sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample
cache_video(
tensor=sample[None],
@ -2578,7 +2689,8 @@ def generate_video(
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
state['update_gallery'] = True
seed += 1
if not sliding_window:
seed += 1
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
@ -2640,6 +2752,8 @@ def process_tasks(state, progress=gr.Progress()):
finally:
if not ok:
queue.clear()
gen["prompts_max"] = 0
gen["prompt"] = ""
yield status
queue[:] = [item for item in queue if item['id'] != task['id']]
@ -2654,17 +2768,19 @@ def process_tasks(state, progress=gr.Progress()):
yield f"Total Generation Time: {end_time-start_time:.1f}s"
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max):
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, sliding_window):
item = "Sliding Window" if sliding_window else "Sample"
if prompts_max == 1:
if repeat_max == 1:
return "Video"
else:
return f"Sample {repeat_no}/{repeat_max}"
return f"{item} {repeat_no}/{repeat_max}"
else:
if repeat_max == 1:
return f"Prompt {prompt_no}/{prompts_max}"
else:
return f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}"
return f"Prompt {prompt_no}/{prompts_max}, {item} {repeat_no}/{repeat_max}"
refresh_id = 0
@ -2680,7 +2796,8 @@ def update_status(state):
prompts_max = gen.get("prompts_max",0)
total_generation = gen["total_generation"]
repeat_no = gen["repeat_no"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
sliding_window = gen["sliding_window"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
gen["progress_status"] = status
gen["refresh"] = get_new_refresh_id()
@ -2697,7 +2814,7 @@ def one_more_sample(state):
prompts_max = gen.get("prompts_max",0)
total_generation = gen["total_generation"] + extra_orders
repeat_no = gen["repeat_no"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, gen.get("sliding_window",False))
gen["progress_status"] = status
@ -3019,7 +3136,7 @@ def prepare_inputs_dict(target, inputs ):
if not "Vace" in model_filename:
unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"]
unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
for k in unsaved_params:
inputs.pop(k)
@ -3061,7 +3178,10 @@ def save_inputs(
image_refs,
video_guide,
video_mask,
max_frames,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
@ -3251,6 +3371,13 @@ def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_v
visible = "V" in video_prompt_type
return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
def refresh_video_prompt_video_guide_trigger(video_prompt_type, video_prompt_type_video_guide):
video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= "V" in video_prompt_type ), gr.update(visible= "M" in video_prompt_type) , gr.update(visible= "V" in video_prompt_type )
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
global inputs_names #, advanced
@ -3370,12 +3497,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Transfer Depth from the Control Video", "DV"),
("Recolorize the Control Video", "CV"),
# ("Alternate Video Ending", "OV"),
("(adv) Video contains Open Pose, Depth or Black & White ", "V"),
("(adv) Inpainting of Control Video using Mask Video ", "MV"),
("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
("Control Video and Mask video for stronger Inpainting ", "MV"),
],
value=filter_letters(video_prompt_type_value, "ODPCMV"),
label="Video to Video", scale = 3
)
video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
video_prompt_type_image_refs = gr.Dropdown(
choices=[
@ -3389,8 +3517,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
# video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
keep_frames = gr.Text(value=ui_defaults.get("keep_frames","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
image_refs = gr.Gallery( label ="Reference Images",
type ="pil", show_label= True,
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
@ -3465,28 +3592,32 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Resolution"
)
with gr.Row():
with gr.Column():
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
with gr.Column():
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui)
with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Column():
seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)")
with gr.Row():
repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0),
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
with gr.Tabs(visible=advanced_ui) as advanced_row:
# with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Tab("Generation"):
with gr.Column():
seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)")
with gr.Row():
repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0),
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
with gr.Tab("Loras"):
with gr.Column(visible = True): #as loras_column:
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
loras_choices = gr.Dropdown(
@ -3500,7 +3631,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str)
with gr.Row():
gr.Markdown("<B>Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)</B>")
with gr.Row():
with gr.Tab("Speed"):
with gr.Column():
gr.Markdown("<B>Tea Cache accelerates the Video generation by skipping denoising steps. This may impact the quality</B>")
tea_cache_setting = gr.Dropdown(
choices=[
("Tea Cache Disabled", 0),
@ -3516,9 +3650,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("tea_cache_start_step_perc",0), step=1, label="Tea Cache starting moment in % of generation")
with gr.Row():
with gr.Tab("Upsampling"):
with gr.Column():
gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
with gr.Row():
temporal_upsampling = gr.Dropdown(
choices=[
("Disabled", ""),
@ -3542,6 +3677,59 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Spatial Upsampling"
)
with gr.Tab("Quality"):
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("slg_switch",0),
visible=True,
scale = 1,
label="Skip Layer guidance"
)
slg_layers = gr.Dropdown(
choices=[
(str(i), i ) for i in range(40)
],
value=ui_defaults.get("slg_layers", ["9"]),
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Tab("Sliding Window", visible= "Vace" in model_filename ) as sliding_window_tab:
with gr.Column():
gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
sliding_window_repeat = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_repeat", 0), step=1, label="Sliding Window Iterations (0 = Disabled)")
sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
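# Editor's note (illustrative arithmetic, not part of the original file): with
# the defaults above -- 81 frames per window, a 16-frame overlap and 4 discarded
# trailing frames -- the first window contributes 81 - 4 = 77 frames and each
# further iteration adds roughly 81 - 16 - 4 = 61 new frames, so 3 iterations
# yield about 77 + 2 * 61 = 199 frames before any temporal upsampling.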
with gr.Tab("Miscellaneous"):
gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
RIFLEx_setting = gr.Dropdown(
choices=[
@ -3552,50 +3740,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
value=ui_defaults.get("RIFLEx_setting",0),
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("slg_switch",0),
visible=True,
scale = 1,
label="Skip Layer guidance"
)
slg_layers = gr.Dropdown(
choices=[
(str(i), i ) for i in range(40)
],
value=ui_defaults.get("slg_layers", ["9"]),
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
if not update_form:
with gr.Column():
@ -3795,7 +3942,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@ -3805,9 +3952,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
target_settings = gr.Text(value = "settings", interactive= False, visible= False)
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
# video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, video_mask, keep_frames])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames, video_mask])
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
@ -3910,12 +4057,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
return (
loras_choices, lset_name, state, queue_df, current_gen_column,
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
gen_info,
prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
*prompt_vars
)
gen_info, queue_accordion, video_guide, video_mask, video_prompt_video_guide_trigger
)
def generate_download_tab(lset_name,loras_choices, state):
@ -4099,13 +4242,14 @@ def generate_about_tab():
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
def generate_info_tab():
@ -4139,8 +4283,30 @@ def generate_dropdown_model_list():
)
def select_tab(tab_state, evt:gr.SelectData):
tab_video_mask_creator = 2
old_tab_no = tab_state.get("tab_no",0)
new_tab_no = evt.index
if old_tab_no == tab_video_mask_creator:
vmc_event_handler(False)
elif new_tab_no == tab_video_mask_creator:
if gen_in_progress:
gr.Info("Unable to access this Tab while a Generation is in Progress. Please come back later")
tab_state["tab_auto"]=old_tab_no
else:
vmc_event_handler(True)
tab_state["tab_no"] = new_tab_no
def select_tab_auto(tab_state):
old_tab_no = tab_state.pop("tab_auto", -1)
if old_tab_no>= 0:
tab_state["tab_auto"]=old_tab_no
return gr.Tabs(selected=old_tab_no) # !! doesnt work !!
return gr.Tab()
def create_demo():
global vmc_event_handler
css = """
#model_list{
background-color:black;
@ -4377,6 +4543,8 @@ def create_demo():
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })
with gr.Tabs(selected="video_gen", ) as main_tabs:
with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
with gr.Row():
@ -4393,14 +4561,15 @@ def create_demo():
(
loras_choices, lset_name, state, queue_df, current_gen_column,
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
gen_info,
prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
*prompt_vars_outputs
gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
) = generate_video_tab(model_choice=model_choice, header=header)
with gr.Tab("Informations"):
generate_info_tab()
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
from preprocessing.matanyone import app as matanyone_app
vmc_event_handler = matanyone_app.get_vmc_event_handler()
matanyone_app.display(video_guide, video_mask, video_prompt_type_video_trigger)
if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_download_tab(lset_name, loras_choices, state)
@ -4427,6 +4596,7 @@ def create_demo():
trigger_mode="always_last"
)
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= None).then(fn=select_tab_auto, inputs= [tab_state], outputs=[main_tabs])
return demo
if __name__ == "__main__":