Simplified Vace, added auto open pose and depth extractors

This commit is contained in:
DeepBeepMeep 2025-04-09 15:51:23 +02:00
parent fea835f21f
commit 9ac1674615
23 changed files with 3316 additions and 104 deletions

View File

@@ -14,7 +14,7 @@
## 🔥 Latest News!!
* April 4 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
- A new queuing system that lets you stack in a queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...).
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
@@ -272,23 +272,24 @@ You can define multiple lines of macros. If there is only one macro line, the ap
### VACE ControlNet introduction
Vace is a 1.3B ControlNet text2video model that lets you provide visual hints, on top of a text prompt, to guide the generation. It can do more things than image2video, although it is not as good at simply starting a video from an image because it is only a 1.3B model (in fact 3B) versus 14B, and it is not specialized for start frames. However, with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ...
Vace is a 1.3B ControlNet text2video model that lets you do Video to Video and Reference to Video (inject your own images into the output video). So with Vace you can inject people or objects of your choice into the scene, animate a person, perform inpainting or outpainting, continue a video, ...
First you need to select the Vace 1.3B model in the drop-down box at the top. Please note that for the moment Vace works well only with videos up to 5s (81 frames).
Besides the usual Text Prompt, three new types of visual hints can be provided (and combined!):
- Reference Images: use these to inject people or objects into the video. You can select multiple Reference Images. The integration of an image works better if its background is replaced by plain white. You can do that with your preferred background remover or use the built-in background remover by checking the box *Remove background*.
- a Control Video: based on your choice, you can transfer the motion or the depth into a new Video. You can tell WanGP to use only the first n frames of the Control Video and to extrapolate the rest. You can also do inpainting: if the video contains areas of grey color 127, they will be considered as masks and will be filled based on the Text Prompt and the Reference Images.
- a Video: this can be a video that contains a body pose (an animated wireframe that indicates the positions of a person's limbs), a grey depth-map video, a normal video combined with a masked video (see below), ... The Vace model will automatically detect what to do depending on the video content. You can tell WanGP to use only the first n frames of this Video. All the frames beyond that, up to the number of requested frames, will be generated by following the Text Prompt and the other visual hints (for instance Reference Images). If the video contains areas of grey color 127, they will be considered as masks and will be filled based on the Text Prompt and the Reference Images (a small illustrative sketch is given after the examples below).
- Reference Images: use these to inject people or objects of your choice into the video. You can select multiple Reference Images. The integration of an image works better if its background is replaced by plain white. You can do that with your preferred background remover or use the built-in background remover by checking the box *Remove background*.
- a Video Mask
This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can also do inpainting / outpainting and fill the missing parts of a video more efficiently than with just the video hint.
This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can also do inpainting / outpainting and fill the missing parts of a video more efficiently than with just the video hint. If a frame of the video mask is fully white, that whole frame will be generated: so with black mask frames at the beginning and at the end and white mask frames in the middle, you can generate the missing frames in between.
Examples:
- Inject people and / or objects into a scene described by a text prompt: Ref. Images + text Prompt
- Animate a character described in a text prompt: Body Pose Video + text Prompt
- Animate a character of your choice: Ref. Images + Body Pose Video + text Prompt
- Inject people and / or objects into a scene described by a text prompt: Ref. Images + text Prompt
- Animate a character described in a text prompt: a Video of a person moving + text Prompt
- Animate a character of your choice (pose transfer): Ref. Images + a Video of a person moving + text Prompt
- Change the style of a scene (depth transfer): a Video that contains objects / persons at different depths + text Prompt
There are lots of possible combinations. Some of them require you to prepare some materials (masks on top of a video, full masks, etc.).
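
To make the grey-127 mask convention above concrete, here is a small illustrative sketch (not part of WanGP itself; the paths and rectangle coordinates are placeholders) that paints a grey rectangle into every frame of a control video so that Vace regenerates only that area from the Text Prompt and the Reference Images:

```python
# Illustrative only: mark a rectangular area of a control video as a mask
# by painting it grey (127). Paths and coordinates are placeholders.
import cv2

cap = cv2.VideoCapture("control_input.mp4")
fps = cap.get(cv2.CAP_PROP_FPS) or 16
frames = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame[100:400, 200:600] = 127          # grey 127 area = "fill this in" for Vace
    frames.append(frame)
cap.release()

h, w = frames[0].shape[:2]
out = cv2.VideoWriter("control_masked.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
for frame in frames:
    out.write(frame)
out.release()
```

A black-and-white Video Mask works the same way at the pixel level: black areas are kept and white areas are regenerated.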

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import onnxruntime
def nms(boxes, scores, nms_thr):
"""Single class NMS implemented in Numpy."""
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= nms_thr)[0]
order = order[inds + 1]
return keep
def multiclass_nms(boxes, scores, nms_thr, score_thr):
"""Multiclass NMS implemented in Numpy. Class-aware version."""
final_dets = []
num_classes = scores.shape[1]
for cls_ind in range(num_classes):
cls_scores = scores[:, cls_ind]
valid_score_mask = cls_scores > score_thr
if valid_score_mask.sum() == 0:
continue
else:
valid_scores = cls_scores[valid_score_mask]
valid_boxes = boxes[valid_score_mask]
keep = nms(valid_boxes, valid_scores, nms_thr)
if len(keep) > 0:
cls_inds = np.ones((len(keep), 1)) * cls_ind
dets = np.concatenate(
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
)
final_dets.append(dets)
if len(final_dets) == 0:
return None
return np.concatenate(final_dets, 0)
def demo_postprocess(outputs, img_size, p6=False):
grids = []
expanded_strides = []
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
hsizes = [img_size[0] // stride for stride in strides]
wsizes = [img_size[1] // stride for stride in strides]
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape, 1), stride))
grids = np.concatenate(grids, 1)
expanded_strides = np.concatenate(expanded_strides, 1)
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
return outputs
def preprocess(img, input_size, swap=(2, 0, 1)):
if len(img.shape) == 3:
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
else:
padded_img = np.ones(input_size, dtype=np.uint8) * 114
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
resized_img = cv2.resize(
img,
(int(img.shape[1] * r), int(img.shape[0] * r)),
interpolation=cv2.INTER_LINEAR,
).astype(np.uint8)
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
padded_img = padded_img.transpose(swap)
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r
def inference_detector(session, oriImg):
input_shape = (640,640)
img, ratio = preprocess(oriImg, input_shape)
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
output = session.run(None, ort_inputs)
predictions = demo_postprocess(output[0], input_shape)[0]
boxes = predictions[:, :4]
scores = predictions[:, 4:5] * predictions[:, 5:]
boxes_xyxy = np.ones_like(boxes)
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
boxes_xyxy /= ratio
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
if dets is not None:
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
isscore = final_scores>0.3
iscat = final_cls_inds == 0
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
final_boxes = final_boxes[isbbox]
else:
final_boxes = np.array([])
return final_boxes
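
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# How inference_detector is typically wired up: load the YOLOX ONNX checkpoint
# into an onnxruntime session and pass a single BGR frame. The checkpoint and
# image paths below are placeholders.
if __name__ == "__main__":
    session = onnxruntime.InferenceSession("yolox_l.onnx",
                                           providers=["CPUExecutionProvider"])
    frame = cv2.imread("person.jpg")                     # BGR uint8, H x W x 3
    person_boxes = inference_detector(session, frame)    # (N, 4) xyxy boxes, class 0 (person) only
    print(person_boxes)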

View File

@@ -0,0 +1,362 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Tuple
import cv2
import numpy as np
import onnxruntime as ort
def preprocess(
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Do preprocessing for RTMPose model inference.
Args:
img (np.ndarray): Input image in shape.
input_size (tuple): Input image size in shape (w, h).
Returns:
tuple:
- resized_img (np.ndarray): Preprocessed image.
- center (np.ndarray): Center of image.
- scale (np.ndarray): Scale of image.
"""
# get shape of image
img_shape = img.shape[:2]
out_img, out_center, out_scale = [], [], []
if len(out_bbox) == 0:
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
for i in range(len(out_bbox)):
x0 = out_bbox[i][0]
y0 = out_bbox[i][1]
x1 = out_bbox[i][2]
y1 = out_bbox[i][3]
bbox = np.array([x0, y0, x1, y1])
# get center and scale
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
# do affine transformation
resized_img, scale = top_down_affine(input_size, scale, center, img)
# normalize image
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
resized_img = (resized_img - mean) / std
out_img.append(resized_img)
out_center.append(center)
out_scale.append(scale)
return out_img, out_center, out_scale
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
"""Inference RTMPose model.
Args:
sess (ort.InferenceSession): ONNXRuntime session.
img (np.ndarray): Input image in shape.
Returns:
outputs (np.ndarray): Output of RTMPose model.
"""
all_out = []
# build input
for i in range(len(img)):
input = [img[i].transpose(2, 0, 1)]
# build output
sess_input = {sess.get_inputs()[0].name: input}
sess_output = []
for out in sess.get_outputs():
sess_output.append(out.name)
# run model
outputs = sess.run(sess_output, sess_input)
all_out.append(outputs)
return all_out
def postprocess(outputs: List[np.ndarray],
model_input_size: Tuple[int, int],
center: Tuple[int, int],
scale: Tuple[int, int],
simcc_split_ratio: float = 2.0
) -> Tuple[np.ndarray, np.ndarray]:
"""Postprocess for RTMPose model output.
Args:
outputs (np.ndarray): Output of RTMPose model.
model_input_size (tuple): RTMPose model Input image size.
center (tuple): Center of bbox in shape (x, y).
scale (tuple): Scale of bbox in shape (w, h).
simcc_split_ratio (float): Split ratio of simcc.
Returns:
tuple:
- keypoints (np.ndarray): Rescaled keypoints.
- scores (np.ndarray): Model predict scores.
"""
all_key = []
all_score = []
for i in range(len(outputs)):
# use simcc to decode
simcc_x, simcc_y = outputs[i]
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
# rescale keypoints
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
all_key.append(keypoints[0])
all_score.append(scores[0])
return np.array(all_key), np.array(all_score)
def bbox_xyxy2cs(bbox: np.ndarray,
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
"""Transform the bbox format from (x,y,w,h) into (center, scale)
Args:
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
as (left, top, right, bottom)
padding (float): BBox padding factor that will be multilied to scale.
Default: 1.0
Returns:
tuple: A tuple containing center and scale.
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
(n, 2)
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
(n, 2)
"""
# convert single bbox from (4, ) to (1, 4)
dim = bbox.ndim
if dim == 1:
bbox = bbox[None, :]
# get bbox center and scale
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
scale = np.hstack([x2 - x1, y2 - y1]) * padding
if dim == 1:
center = center[0]
scale = scale[0]
return center, scale
def _fix_aspect_ratio(bbox_scale: np.ndarray,
aspect_ratio: float) -> np.ndarray:
"""Extend the scale to match the given aspect ratio.
Args:
scale (np.ndarray): The image scale (w, h) in shape (2, )
aspect_ratio (float): The ratio of ``w/h``
Returns:
np.ndarray: The reshaped image scale in (2, )
"""
w, h = np.hsplit(bbox_scale, [1])
bbox_scale = np.where(w > h * aspect_ratio,
np.hstack([w, w / aspect_ratio]),
np.hstack([h * aspect_ratio, h]))
return bbox_scale
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
"""Rotate a point by an angle.
Args:
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
angle_rad (float): rotation angle in radian
Returns:
np.ndarray: Rotated point in shape (2, )
"""
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
rot_mat = np.array([[cs, -sn], [sn, cs]])
return rot_mat @ pt
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): The 1st point (x,y) in shape (2, )
b (np.ndarray): The 2nd point (x,y) in shape (2, )
Returns:
np.ndarray: The 3rd point.
"""
direction = a - b
c = b + np.r_[-direction[1], direction[0]]
return c
def get_warp_matrix(center: np.ndarray,
scale: np.ndarray,
rot: float,
output_size: Tuple[int, int],
shift: Tuple[float, float] = (0., 0.),
inv: bool = False) -> np.ndarray:
"""Calculate the affine transformation matrix that can warp the bbox area
in the input image to the output size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ] | list(2,)): Size of the
destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: A 2x3 transformation matrix
"""
shift = np.array(shift)
src_w = scale[0]
dst_w = output_size[0]
dst_h = output_size[1]
# compute transformation matrix
rot_rad = np.deg2rad(rot)
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
# get four corners of the src rectangle in the original image
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale * shift
src[1, :] = center + src_dir + scale * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
# get four corners of the dst rectangle in the input image
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return warp_mat
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Get the bbox image as the model input by affine transform.
Args:
input_size (dict): The input size of the model.
bbox_scale (dict): The bbox scale of the img.
bbox_center (dict): The bbox center of the img.
img (np.ndarray): The original image.
Returns:
tuple: A tuple containing center and scale.
- np.ndarray[float32]: img after affine transform.
- np.ndarray[float32]: bbox scale after affine transform.
"""
w, h = input_size
warp_size = (int(w), int(h))
# reshape bbox to fixed aspect ratio
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
# get the affine matrix
center = bbox_center
scale = bbox_scale
rot = 0
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
# do affine transform
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
return img, bbox_scale
def get_simcc_maximum(simcc_x: np.ndarray,
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Get maximum response location and value from simcc representations.
Note:
instance number: N
num_keypoints: K
heatmap height: H
heatmap width: W
Args:
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
Returns:
tuple:
- locs (np.ndarray): locations of maximum heatmap responses in shape
(K, 2) or (N, K, 2)
- vals (np.ndarray): values of maximum heatmap responses in shape
(K,) or (N, K)
"""
N, K, Wx = simcc_x.shape
simcc_x = simcc_x.reshape(N * K, -1)
simcc_y = simcc_y.reshape(N * K, -1)
# get maximum value locations
x_locs = np.argmax(simcc_x, axis=1)
y_locs = np.argmax(simcc_y, axis=1)
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
max_val_x = np.amax(simcc_x, axis=1)
max_val_y = np.amax(simcc_y, axis=1)
# get maximum value across x and y axis
mask = max_val_x > max_val_y
max_val_x[mask] = max_val_y[mask]
vals = max_val_x
locs[vals <= 0.] = -1
# reshape
locs = locs.reshape(N, K, 2)
vals = vals.reshape(N, K)
return locs, vals
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
"""Modulate simcc distribution with Gaussian.
Args:
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
simcc_split_ratio (int): The split ratio of simcc.
Returns:
tuple: A tuple containing center and scale.
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
- np.ndarray[float32]: scores in shape (K,) or (n, K)
"""
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
keypoints /= simcc_split_ratio
return keypoints, scores
def inference_pose(session, out_bbox, oriImg):
h, w = session.get_inputs()[0].shape[2:]
model_input_size = (w, h)
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
outputs = inference(session, resized_img)
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
return keypoints, scores
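
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# inference_pose expects the person boxes produced by the YOLOX detector in
# onnxdet.py and returns per-person keypoints and scores. Checkpoint and image
# paths are placeholders; inside the package the detector is imported as
# `from .onnxdet import inference_detector`.
if __name__ == "__main__":
    from onnxdet import inference_detector   # assumption: run next to onnxdet.py

    det_sess = ort.InferenceSession("yolox_l.onnx", providers=["CPUExecutionProvider"])
    pose_sess = ort.InferenceSession("dw-ll_ucoco_384.onnx", providers=["CPUExecutionProvider"])
    img = cv2.imread("person.jpg")                       # BGR uint8
    bboxes = inference_detector(det_sess, img)           # (N, 4) person boxes, xyxy
    keypoints, scores = inference_pose(pose_sess, bboxes, img)
    # keypoints: (N, K, 2), scores: (N, K); K is 133 for the DWPose whole-body model
    print(keypoints.shape, scores.shape)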

View File

@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import torch
import numpy as np
from . import util
from .wholebody import Wholebody, HWC3, resize_image
from PIL import Image
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def convert_to_numpy(image):
if isinstance(image, Image.Image):
image = np.array(image)
elif isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
elif isinstance(image, np.ndarray):
image = image.copy()
else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
return image
def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
bodies = pose['bodies']
faces = pose['faces']
hands = pose['hands']
candidate = bodies['candidate']
subset = bodies['subset']
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if use_body:
canvas = util.draw_bodypose(canvas, candidate, subset)
if use_hand:
canvas = util.draw_handpose(canvas, hands)
if use_face:
canvas = util.draw_facepose(canvas, faces)
return canvas
class PoseAnnotator:
def __init__(self, cfg, device=None):
onnx_det = cfg['DETECTION_MODEL']
onnx_pose = cfg['POSE_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.pose_estimation = Wholebody(onnx_det, onnx_pose, device=self.device)
self.resize_size = cfg.get("RESIZE_SIZE", 1024)
self.use_body = cfg.get('USE_BODY', True)
self.use_face = cfg.get('USE_FACE', True)
self.use_hand = cfg.get('USE_HAND', True)
@torch.no_grad()
    @torch.inference_mode()
def forward(self, image):
image = convert_to_numpy(image)
input_image = HWC3(image[..., ::-1])
return self.process(resize_image(input_image, self.resize_size), image.shape[:2])
def process(self, ori_img, ori_shape):
ori_h, ori_w = ori_shape
ori_img = ori_img.copy()
H, W, C = ori_img.shape
with torch.no_grad():
candidate, subset, det_result = self.pose_estimation(ori_img)
nums, keys, locs = candidate.shape
candidate[..., 0] /= float(W)
candidate[..., 1] /= float(H)
body = candidate[:, :18].copy()
body = body.reshape(nums * 18, locs)
score = subset[:, :18]
for i in range(len(score)):
for j in range(len(score[i])):
if score[i][j] > 0.3:
score[i][j] = int(18 * i + j)
else:
score[i][j] = -1
un_visible = subset < 0.3
candidate[un_visible] = -1
foot = candidate[:, 18:24]
faces = candidate[:, 24:92]
hands = candidate[:, 92:113]
hands = np.vstack([hands, candidate[:, 113:]])
bodies = dict(candidate=body, subset=score)
pose = dict(bodies=bodies, hands=hands, faces=faces)
ret_data = {}
if self.use_body:
detected_map_body = draw_pose(pose, H, W, use_body=True)
detected_map_body = cv2.resize(detected_map_body[..., ::-1], (ori_w, ori_h),
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
ret_data["detected_map_body"] = detected_map_body
if self.use_face:
detected_map_face = draw_pose(pose, H, W, use_face=True)
detected_map_face = cv2.resize(detected_map_face[..., ::-1], (ori_w, ori_h),
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
ret_data["detected_map_face"] = detected_map_face
if self.use_body and self.use_face:
detected_map_bodyface = draw_pose(pose, H, W, use_body=True, use_face=True)
detected_map_bodyface = cv2.resize(detected_map_bodyface[..., ::-1], (ori_w, ori_h),
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
ret_data["detected_map_bodyface"] = detected_map_bodyface
if self.use_hand and self.use_body and self.use_face:
detected_map_handbodyface = draw_pose(pose, H, W, use_hand=True, use_body=True, use_face=True)
detected_map_handbodyface = cv2.resize(detected_map_handbodyface[..., ::-1], (ori_w, ori_h),
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
ret_data["detected_map_handbodyface"] = detected_map_handbodyface
# convert_size
if det_result.shape[0] > 0:
w_ratio, h_ratio = ori_w / W, ori_h / H
det_result[..., ::2] *= h_ratio
det_result[..., 1::2] *= w_ratio
det_result = det_result.astype(np.int32)
return ret_data, det_result
class PoseBodyFaceAnnotator(PoseAnnotator):
def __init__(self, cfg):
super().__init__(cfg)
self.use_body, self.use_face, self.use_hand = True, True, False
@torch.no_grad()
    @torch.inference_mode()
def forward(self, image):
ret_data, det_result = super().forward(image)
return ret_data['detected_map_bodyface']
class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
def forward(self, frames):
ret_frames = []
for frame in frames:
anno_frame = super().forward(np.array(frame))
ret_frames.append(anno_frame)
return ret_frames
import imageio
def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
try:
video_writer = imageio.get_writer(file_path, fps=fps, codec='libx264', quality=quality, macro_block_size=macro_block_size)
for frame in videos:
video_writer.append_data(frame)
video_writer.close()
return True
except Exception as e:
print(f"Video save error: {e}")
return False
def get_frames(video_path):
frames = []
# Opens the Video file with CV2
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
print("video fps: " + str(fps))
i = 0
while cap.isOpened():
ret, frame = cap.read()
        if not ret:
break
frames.append(frame)
i += 1
cap.release()
cv2.destroyAllWindows()
return frames, fps
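
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# End-to-end pose extraction for a whole video: read the frames, draw the
# body+face pose maps and save a preview video. The checkpoint and video paths
# are placeholders, and this file uses relative imports, so it has to be run as
# a module inside its package (module path assumed), not as a loose script.
if __name__ == "__main__":
    cfg = {
        "DETECTION_MODEL": "yolox_l.onnx",
        "POSE_MODEL": "dw-ll_ucoco_384.onnx",
        "RESIZE_SIZE": 1024,
    }
    frames, fps = get_frames("input.mp4")      # BGR frames as returned by OpenCV
    annotator = PoseBodyFaceVideoAnnotator(cfg)
    pose_frames = annotator.forward(frames)    # list of H x W x 3 uint8 pose maps
    save_one_video("pose_preview.mp4", pose_frames, fps=fps or 16)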

View File

@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import numpy as np
import matplotlib
import cv2
eps = 0.01
def smart_resize(x, s):
Ht, Wt = s
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
def smart_resize_k(x, fx, fy):
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
Ht, Wt = Ho * fy, Wo * fx
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
return transfered_model_weights
def draw_bodypose(canvas, candidate, subset):
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
stickwidth = 4
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
[1, 16], [16, 18], [3, 17], [6, 18]]
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
for i in range(17):
for n in range(len(subset)):
index = subset[n][np.array(limbSeq[i]) - 1]
if -1 in index:
continue
Y = candidate[index.astype(int), 0] * float(W)
X = candidate[index.astype(int), 1] * float(H)
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, colors[i])
canvas = (canvas * 0.6).astype(np.uint8)
for i in range(18):
for n in range(len(subset)):
index = int(subset[n][i])
if index == -1:
continue
x, y = candidate[index][0:2]
x = int(x * W)
y = int(y * H)
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
return canvas
def draw_handpose(canvas, all_hand_peaks):
H, W, C = canvas.shape
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
for peaks in all_hand_peaks:
peaks = np.array(peaks)
for ie, e in enumerate(edges):
x1, y1 = peaks[e[0]]
x2, y2 = peaks[e[1]]
x1 = int(x1 * W)
y1 = int(y1 * H)
x2 = int(x2 * W)
y2 = int(y2 * H)
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
        for i, keypoint in enumerate(peaks):
            x, y = keypoint
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
return canvas
def draw_facepose(canvas, all_lmks):
H, W, C = canvas.shape
for lmks in all_lmks:
lmks = np.array(lmks)
for lmk in lmks:
x, y = lmk
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
# if any of three not detected
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
if not (has_left or has_right):
continue
hands = []
#left hand
if has_left:
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
x1, y1 = candidate[left_shoulder_index][:2]
x2, y2 = candidate[left_elbow_index][:2]
x3, y3 = candidate[left_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, True])
# right hand
if has_right:
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
x1, y1 = candidate[right_shoulder_index][:2]
x2, y2 = candidate[right_elbow_index][:2]
x3, y3 = candidate[right_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, False])
for x1, y1, x2, y2, x3, y3, is_left in hands:
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
# handRectangle.x -= handRectangle.width / 2.f;
# handRectangle.y -= handRectangle.height / 2.f;
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0: x = 0
if y < 0: y = 0
width1 = width
width2 = width
if x + width > image_width: width1 = image_width - x
if y + width > image_height: width2 = image_height - y
width = min(width1, width2)
            # keep only hand boxes of at least 20 pixels
if width >= 20:
detect_result.append([int(x), int(y), int(width), is_left])
'''
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left
'''
return detect_result
# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
# left right eye ear 14 15 16 17
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
has_head = person[0] > -1
if not has_head:
continue
has_left_eye = person[14] > -1
has_right_eye = person[15] > -1
has_left_ear = person[16] > -1
has_right_ear = person[17] > -1
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
continue
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
width = 0.0
x0, y0 = candidate[head][:2]
if has_left_eye:
x1, y1 = candidate[left_eye][:2]
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if has_right_eye:
x1, y1 = candidate[right_eye][:2]
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if has_left_ear:
x1, y1 = candidate[left_ear][:2]
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
if has_right_ear:
x1, y1 = candidate[right_ear][:2]
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
x, y = x0, y0
x -= width
y -= width
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width * 2
width2 = width * 2
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
if width >= 20:
detect_result.append([int(x), int(y), int(width)])
return detect_result
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return i, j
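
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# draw_bodypose expects keypoints normalised to [0, 1]: `candidate` holds the
# (x, y) coordinates and `subset` holds, for each person, the index of each of
# the 18 body keypoints inside `candidate` (-1 marks a missing joint).
if __name__ == "__main__":
    canvas = np.zeros((512, 512, 3), dtype=np.uint8)
    # one synthetic person: 18 keypoints placed on a circle, all visible
    angles = np.linspace(0, 2 * np.pi, 18, endpoint=False)
    candidate = np.stack([0.5 + 0.3 * np.cos(angles), 0.5 + 0.3 * np.sin(angles)], axis=1)
    subset = np.arange(18, dtype=float)[None, :]
    canvas = draw_bodypose(canvas, candidate, subset)
    cv2.imwrite("bodypose_demo.png", canvas)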

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import onnxruntime as ort
from .onnxdet import inference_detector
from .onnxpose import inference_pose
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
class Wholebody:
def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
providers = ['CPUExecutionProvider'
] if device == 'cpu' else ['CUDAExecutionProvider']
# onnx_det = 'annotator/ckpts/yolox_l.onnx'
# onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
def __call__(self, ori_img):
det_result = inference_detector(self.session_det, ori_img)
keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
keypoints_info = np.concatenate(
(keypoints, scores[..., None]), axis=-1)
# compute neck joint
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
# neck score when visualizing pred
neck[:, 2:4] = np.logical_and(
keypoints_info[:, 5, 2:4] > 0.3,
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
new_keypoints_info = np.insert(
keypoints_info, 17, neck, axis=1)
mmpose_idx = [
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
]
openpose_idx = [
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
]
new_keypoints_info[:, openpose_idx] = \
new_keypoints_info[:, mmpose_idx]
keypoints_info = new_keypoints_info
keypoints, scores = keypoints_info[
..., :2], keypoints_info[..., 2]
return keypoints, scores, det_result
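
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# Wholebody chains the YOLOX person detector and the DWPose keypoint model,
# inserts a synthetic neck joint (midpoint of the shoulders) and remaps the
# MMPose/COCO joint order to the OpenPose order used by the drawing utilities.
# Checkpoint and image paths are placeholders; because of the relative imports
# above, run this as a module inside its package rather than as a loose script.
if __name__ == "__main__":
    estimator = Wholebody("yolox_l.onnx", "dw-ll_ucoco_384.onnx", device="cpu")
    img = cv2.imread("person.jpg")             # BGR uint8
    keypoints, scores, det_result = estimator(img)
    print(keypoints.shape, scores.shape, det_result.shape)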

35
preprocessing/gray.py Normal file
View File

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
from PIL import Image
import torch
def convert_to_numpy(image):
if isinstance(image, Image.Image):
image = np.array(image)
elif isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
elif isinstance(image, np.ndarray):
image = image.copy()
else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
return image
class GrayAnnotator:
def __init__(self, cfg):
pass
def forward(self, image):
image = convert_to_numpy(image)
gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
return gray_map[..., None].repeat(3, axis=2)
class GrayVideoAnnotator(GrayAnnotator):
def forward(self, frames):
ret_frames = []
for frame in frames:
anno_frame = super().forward(np.array(frame))
ret_frames.append(anno_frame)
return ret_frames

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

166
preprocessing/midas/api.py Normal file
View File

@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from .dpt_depth import DPTDepthModel
from .midas_net import MidasNet
from .midas_net_custom import MidasNet_small
from .transforms import NormalizeImage, PrepareForNet, Resize
# ISL_PATHS = {
# "dpt_large": "dpt_large-midas-2f21e586.pt",
# "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
# "midas_v21": "",
# "midas_v21_small": "",
# }
# remote_model_path =
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def load_midas_transform(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == 'dpt_large': # DPT-Large
net_w, net_h = 384, 384
resize_mode = 'minimal'
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
elif model_type == 'dpt_hybrid': # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = 'minimal'
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
elif model_type == 'midas_v21':
net_w, net_h = 384, 384
resize_mode = 'upper_bound'
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
elif model_type == 'midas_v21_small':
net_w, net_h = 256, 256
resize_mode = 'upper_bound'
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
else:
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose([
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
])
return transform
def load_model(model_type, model_path):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
# model_path = ISL_PATHS[model_type]
if model_type == 'dpt_large': # DPT-Large
model = DPTDepthModel(
path=model_path,
backbone='vitl16_384',
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = 'minimal'
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
elif model_type == 'dpt_hybrid': # DPT-Hybrid
model = DPTDepthModel(
path=model_path,
backbone='vitb_rn50_384',
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = 'minimal'
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
elif model_type == 'midas_v21':
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = 'upper_bound'
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
elif model_type == 'midas_v21_small':
model = MidasNet_small(model_path,
features=64,
backbone='efficientnet_lite3',
exportable=True,
non_negative=True,
blocks={'expand': True})
net_w, net_h = 256, 256
resize_mode = 'upper_bound'
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
else:
print(
f"model_type '{model_type}' not implemented, use: --model_type large"
)
assert False
transform = Compose([
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
])
return model.eval(), transform
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
MODEL_TYPES_ISL = [
'dpt_large',
'dpt_hybrid',
'midas_v21',
'midas_v21_small',
]
def __init__(self, model_type, model_path):
super().__init__()
assert (model_type in self.MODEL_TYPES_ISL)
model, _ = load_model(model_type, model_path)
self.model = model
self.model.train = disabled_train
def forward(self, x):
with torch.no_grad():
prediction = self.model(x)
return prediction
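
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# MiDaSInference expects an already normalised (1, 3, H, W) float tensor; the
# matching preprocessing pipeline is returned by load_midas_transform. The
# checkpoint and image paths are placeholders; because of the relative imports
# above, run this as `python -m preprocessing.midas.api` rather than directly.
if __name__ == "__main__":
    model = MiDaSInference("dpt_hybrid", "dpt_hybrid-midas-501f0c75.pt")
    transform = load_midas_transform("dpt_hybrid")
    image = cv2.imread("scene.jpg")[..., ::-1] / 255.0   # RGB float in [0, 1]
    sample = transform({"image": image})
    x = torch.from_numpy(sample["image"]).unsqueeze(0)
    depth = model(x)                                     # (1, H', W') relative inverse depth
    print(depth.shape)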

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
if 'optimizer' in parameters:
parameters = parameters['model']
self.load_state_dict(parameters)

View File

@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384)
def _make_encoder(
backbone,
features,
use_pretrained,
groups=1,
expand=False,
exportable=True,
hooks=None,
use_vit_only=False,
use_readout='ignore',
):
if backbone == 'vitl16_384':
pretrained = _make_pretrained_vitl16_384(use_pretrained,
hooks=hooks,
use_readout=use_readout)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups,
expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == 'vitb_rn50_384':
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups,
expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == 'vitb16_384':
pretrained = _make_pretrained_vitb16_384(use_pretrained,
hooks=hooks,
use_readout=use_readout)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups,
expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == 'resnext101_wsl':
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048],
features,
groups=groups,
expand=expand) # efficientnet_lite3
elif backbone == 'efficientnet_lite3':
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384],
features,
groups=groups,
expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand is True:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups)
scratch.layer2_rn = nn.Conv2d(in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups)
scratch.layer3_rn = nn.Conv2d(in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups)
scratch.layer4_rn = nn.Conv2d(in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
'tf_efficientnet_lite3',
pretrained=use_pretrained,
exportable=exportable)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
effnet.act1, *effnet.blocks[0:2])
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
resnet.maxpool, resnet.layer1)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load('facebookresearch/WSL-Images',
'resnext101_32x8d_wsl')
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(x,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=True)
self.conv2 = nn.Conv2d(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=True)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(output,
scale_factor=2,
mode='bilinear',
align_corners=True)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=True,
groups=self.groups)
self.conv2 = nn.Conv2d(features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=True,
groups=self.groups)
if self.bn is True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn is True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn is True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand is True:
out_features = features // 2
self.out_conv = nn.Conv2d(features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(output,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)
output = self.out_conv(output)
return output
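
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# FeatureFusionBlock_custom fuses a decoder feature map with an encoder skip
# connection and doubles its spatial resolution. Because of the relative import
# of .vit above, run this as a module inside its package.
if __name__ == "__main__":
    block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=False, align_corners=True)
    decoder_feat = torch.randn(1, 64, 24, 24)
    skip_feat = torch.randn(1, 64, 24, 24)
    fused = block(decoder_feat, skip_feat)
    print(fused.shape)   # torch.Size([1, 64, 48, 48])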

View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from einops import rearrange
from PIL import Image
import cv2
def convert_to_numpy(image):
if isinstance(image, Image.Image):
image = np.array(image)
elif isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
elif isinstance(image, np.ndarray):
image = image.copy()
else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
return image
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(
input_image, (W, H),
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img, k
def resize_image_ori(h, w, image, k):
img = cv2.resize(
image, (w, h),
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
class DepthAnnotator:
def __init__(self, cfg, device=None):
from .api import MiDaSInference
pretrained_model = cfg['PRETRAINED_MODEL']
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
self.a = cfg.get('A', np.pi * 2.0)
self.bg_th = cfg.get('BG_TH', 0.1)
@torch.no_grad()
@torch.inference_mode()
@torch.autocast('cuda', enabled=False)
def forward(self, image):
image = convert_to_numpy(image)
image_depth = image
h, w, c = image.shape
image_depth, k = resize_image(image_depth,
1024 if min(h, w) > 1024 else min(h, w))
image_depth = torch.from_numpy(image_depth).float().to(self.device)
image_depth = image_depth / 127.5 - 1.0
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
depth = self.model(image_depth)[0]
depth_pt = depth.clone()
depth_pt -= torch.min(depth_pt)
depth_pt /= torch.max(depth_pt)
depth_pt = depth_pt.cpu().numpy()
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
depth_image = depth_image[..., None].repeat(3, 2)
depth_image = resize_image_ori(h, w, depth_image, k)
return depth_image
class DepthVideoAnnotator(DepthAnnotator):
def forward(self, frames):
ret_frames = []
for frame in frames:
anno_frame = super().forward(np.array(frame))
ret_frames.append(anno_frame)
return ret_frames
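
# --- Usage sketch (illustrative addition, not part of the upstream file) -----
# The depth annotators are configured with a dict holding the MiDaS checkpoint
# path (placeholder below). Because MiDaSInference is imported relatively, run
# this as a module inside its package rather than as a loose script.
if __name__ == "__main__":
    cfg = {"PRETRAINED_MODEL": "dpt_hybrid-midas-501f0c75.pt"}
    annotator = DepthAnnotator(cfg)
    image = Image.open("scene.jpg")
    depth_map = annotator.forward(image)       # H x W x 3 uint8 depth visualisation
    print(depth_map.shape, depth_map.dtype)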

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
from .vit import forward_vit
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone='vitb_rn50_384',
readout='project',
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
'vitb_rn50_384': [0, 1, 8, 11],
'vitb16_384': [2, 5, 8, 11],
'vitl16_384': [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
            False,  # set to True if you want to train from scratch (uses ImageNet weights)
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last is True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs['features'] if 'features' in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features,
features // 2,
kernel_size=3,
stride=1,
padding=1),
Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print('Loading weights: ', path)
super(MidasNet, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(
backbone='resnext101_wsl',
features=features,
use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode='bilinear'),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)

View File

@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
class MidasNet_small(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self,
path=None,
features=64,
backbone='efficientnet_lite3',
non_negative=True,
exportable=True,
channels_last=False,
align_corners=True,
blocks={'expand': True}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print('Loading weights: ', path)
super(MidasNet_small, self).__init__()
use_pretrained = False if path else True
self.channels_last = channels_last
self.blocks = blocks
self.backbone = backbone
self.groups = 1
features1 = features
features2 = features
features3 = features
features4 = features
self.expand = False
if 'expand' in self.blocks and self.blocks['expand'] is True:
self.expand = True
features1 = features
features2 = features * 2
features3 = features * 4
features4 = features * 8
self.pretrained, self.scratch = _make_encoder(self.backbone,
features,
use_pretrained,
groups=self.groups,
expand=self.expand,
exportable=exportable)
self.scratch.activation = nn.ReLU(False)
self.scratch.refinenet4 = FeatureFusionBlock_custom(
features4,
self.scratch.activation,
deconv=False,
bn=False,
expand=self.expand,
align_corners=align_corners)
self.scratch.refinenet3 = FeatureFusionBlock_custom(
features3,
self.scratch.activation,
deconv=False,
bn=False,
expand=self.expand,
align_corners=align_corners)
self.scratch.refinenet2 = FeatureFusionBlock_custom(
features2,
self.scratch.activation,
deconv=False,
bn=False,
expand=self.expand,
align_corners=align_corners)
self.scratch.refinenet1 = FeatureFusionBlock_custom(
features1,
self.scratch.activation,
deconv=False,
bn=False,
align_corners=align_corners)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features,
features // 2,
kernel_size=3,
stride=1,
padding=1,
groups=self.groups),
Interpolate(scale_factor=2, mode='bilinear'),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
self.scratch.activation,
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if self.channels_last is True:
print('self.channels_last = ', self.channels_last)
x = x.contiguous(memory_format=torch.channels_last)  # reassign: contiguous() returns a new tensor
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
def fuse_model(m):
prev_previous_type = nn.Identity()
prev_previous_name = ''
previous_type = nn.Identity()
previous_name = ''
for name, module in m.named_modules():
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
module) == nn.ReLU:
# print("FUSED ", prev_previous_name, previous_name, name)
torch.quantization.fuse_modules(
m, [prev_previous_name, previous_name, name], inplace=True)
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
# print("FUSED ", prev_previous_name, previous_name)
torch.quantization.fuse_modules(
m, [prev_previous_name, previous_name], inplace=True)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type = previous_type
prev_previous_name = previous_name
previous_type = type(module)
previous_name = name
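
A small hedged sketch of how MidasNet_small and fuse_model above could be combined before quantization; timm/torch.hub may download the efficientnet_lite3 backbone weights on first use, which this example assumes is acceptable.

import torch

small = MidasNet_small(path=None, features=64, backbone='efficientnet_lite3')
small.eval()                 # Conv+BN(+ReLU) fusion requires eval mode
fuse_model(small)            # folds matching Conv2d / BatchNorm2d / ReLU sequences in place
with torch.no_grad():
    out = small(torch.randn(1, 3, 256, 256))
print(out.shape)             # torch.Size([1, 256, 256])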


@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import cv2
import numpy as np
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample['disparity'].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample['image'] = cv2.resize(sample['image'],
tuple(shape[::-1]),
interpolation=image_interpolation_method)
sample['disparity'] = cv2.resize(sample['disparity'],
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST)
sample['mask'] = cv2.resize(
sample['mask'].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample['mask'] = sample['mask'].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. "
"(Output size might be smaller than given size.)"
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) *
self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) *
self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == 'lower_bound':
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == 'upper_bound':
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == 'minimal':
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f'resize_method {self.__resize_method} not implemented')
if self.__resize_method == 'lower_bound':
new_height = self.constrain_to_multiple_of(scale_height * height,
min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width,
min_val=self.__width)
elif self.__resize_method == 'upper_bound':
new_height = self.constrain_to_multiple_of(scale_height * height,
max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width,
max_val=self.__width)
elif self.__resize_method == 'minimal':
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(
f'resize_method {self.__resize_method} not implemented')
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample['image'].shape[1],
sample['image'].shape[0])
# resize sample
sample['image'] = cv2.resize(
sample['image'],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if 'disparity' in sample:
sample['disparity'] = cv2.resize(
sample['disparity'],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if 'depth' in sample:
sample['depth'] = cv2.resize(sample['depth'], (width, height),
interpolation=cv2.INTER_NEAREST)
sample['mask'] = cv2.resize(
sample['mask'].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample['mask'] = sample['mask'].astype(bool)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample['image'] = (sample['image'] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample['image'], (2, 0, 1))
sample['image'] = np.ascontiguousarray(image).astype(np.float32)
if 'mask' in sample:
sample['mask'] = sample['mask'].astype(np.float32)
sample['mask'] = np.ascontiguousarray(sample['mask'])
if 'disparity' in sample:
disparity = sample['disparity'].astype(np.float32)
sample['disparity'] = np.ascontiguousarray(disparity)
if 'depth' in sample:
depth = sample['depth'].astype(np.float32)
sample['depth'] = np.ascontiguousarray(depth)
return sample
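
For context, these three transforms are normally chained into one preprocessing pipeline before an image reaches the network. A hedged sketch, assuming torchvision's Compose and ImageNet normalization constants (both assumptions of this example, not of the code above):

import cv2
import numpy as np
from torchvision.transforms import Compose

transform = Compose([
    Resize(384, 384,
           resize_target=False,            # image-only inference: no mask or disparity to resize
           keep_aspect_ratio=True,
           ensure_multiple_of=32,
           resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

rgb = np.random.rand(480, 640, 3).astype(np.float32)   # stand-in for an RGB image in [0, 1]
sample = transform({'image': rgb})
print(sample['image'].shape)                            # (3, 384, 512): CHW, sides multiples of 32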


@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Utils for monoDepth."""
import re
import sys
import cv2
import numpy as np
import torch
def read_pfm(path):
"""Read pfm file.
Args:
path (str): path to file
Returns:
tuple: (data, scale)
"""
with open(path, 'rb') as file:
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header.decode('ascii') == 'PF':
color = True
elif header.decode('ascii') == 'Pf':
color = False
else:
raise Exception('Not a PFM file: ' + path)
dim_match = re.match(r'^(\d+)\s(\d+)\s$',
file.readline().decode('ascii'))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().decode('ascii').rstrip())
if scale < 0:
# little-endian
endian = '<'
scale = -scale
else:
# big-endian
endian = '>'
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data, scale
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
path (str): path to file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, 'wb') as file:
color = None
if image.dtype.name != 'float32':
raise Exception('Image dtype must be float32.')
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (len(image.shape) == 2
or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
color = False
else:
raise Exception(
'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
file.write(('PF\n' if color else 'Pf\n').encode())  # encode either header: the file is opened in binary mode
file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == '<' or endian == '=' and sys.byteorder == 'little':
scale = -scale
file.write('%f\n'.encode() % scale)
image.tofile(file)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def resize_image(img):
"""Resize image and make it fit for network.
Args:
img (array): image
Returns:
tensor: data ready for network
"""
height_orig = img.shape[0]
width_orig = img.shape[1]
if width_orig > height_orig:
scale = width_orig / 384
else:
scale = height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
img_resized = cv2.resize(img, (width, height),
interpolation=cv2.INTER_AREA)
img_resized = (torch.from_numpy(np.transpose(
img_resized, (2, 0, 1))).contiguous().float())
img_resized = img_resized.unsqueeze(0)
return img_resized
def resize_depth(depth, width, height):
"""Resize depth map and bring to CPU (numpy).
Args:
depth (tensor): depth
width (int): image width
height (int): image height
Returns:
array: processed depth
"""
depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
depth_resized = cv2.resize(depth.numpy(), (width, height),
interpolation=cv2.INTER_CUBIC)
return depth_resized
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
Args:
path (str): filepath without extension
depth (array): depth
"""
write_pfm(path + '.pfm', depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2**(8 * bits)) - 1
if depth_max - depth_min > np.finfo('float').eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
out = np.zeros(depth.shape, dtype=depth.dtype)
if bits == 1:
cv2.imwrite(path + '.png', out.astype('uint8'))
elif bits == 2:
cv2.imwrite(path + '.png', out.astype('uint16'))
return
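
Putting the helpers above together, a hedged end-to-end sketch of the intended read -> resize -> predict -> resize back -> write flow; `model` is a stand-in for a MiDaS network and 'input.jpg' / 'out' are placeholder paths.

import torch

model = lambda t: torch.rand(t.shape[0], t.shape[2], t.shape[3])   # stand-in for MidasNet(...)

img = read_image('input.jpg')                    # HxWx3 RGB in [0, 1]
x = resize_image(img)                            # 1x3xhxw tensor, sides multiples of 32
with torch.no_grad():
    prediction = model(x).unsqueeze(1)           # add a channel dim for resize_depth
depth = resize_depth(prediction, img.shape[1], img.shape[0])   # back to the original (W, H)
write_depth('out', depth, bits=2)                # writes out.pfm and a 16-bit out.png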

preprocessing/midas/vit.py

@ -0,0 +1,510 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import types
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index:]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index:] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
features = torch.cat((x[:, self.start_index:], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
_ = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations['1']
layer_2 = pretrained.activations['2']
layer_3 = pretrained.activations['3']
layer_4 = pretrained.activations['4']
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size([
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]),
))
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
layer_1)
layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
layer_2)
layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
layer_3)
layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, :self.start_index],
posemb[0, self.start_index:],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
-1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid,
size=(gs_h, gs_w),
mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
w // self.patch_size[0])
B = x.shape[0]
if hasattr(self.patch_embed, 'backbone'):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[
-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, 'dist_token', None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == 'ignore':
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == 'add':
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == 'project':
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout='ignore',
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(
get_activation('1'))
pretrained.model.blocks[hooks[1]].register_forward_hook(
get_activation('2'))
pretrained.model.blocks[hooks[2]].register_forward_hook(
get_activation('3'))
pretrained.model.blocks[hooks[3]].register_forward_hook(
get_activation('4'))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout,
start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex,
pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout)
def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
model = timm.create_model('vit_deit_base_patch16_384',
pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout)
def _make_pretrained_deitb16_distil_384(pretrained,
use_readout='ignore',
hooks=None):
model = timm.create_model('vit_deit_base_distilled_patch16_384',
pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout='ignore',
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only is True:
pretrained.model.blocks[hooks[0]].register_forward_hook(
get_activation('1'))
pretrained.model.blocks[hooks[1]].register_forward_hook(
get_activation('2'))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation('1'))
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation('2'))
pretrained.model.blocks[hooks[2]].register_forward_hook(
get_activation('3'))
pretrained.model.blocks[hooks[3]].register_forward_hook(
get_activation('4'))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout,
start_index)
if use_vit_only is True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
nn.Identity(),
nn.Identity())
pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
nn.Identity(),
nn.Identity())
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex,
pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model)
return pretrained
def _make_pretrained_vitb_rn50_384(pretrained,
use_readout='ignore',
hooks=None,
use_vit_only=False):
model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks is None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
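
A brief hedged sketch of what these builders return: a wrapper module whose four hooked feature maps are read out with forward_vit. It assumes timm can instantiate 'vit_large_patch16_384' locally (pretrained=False avoids a weight download).

import torch

backbone = _make_pretrained_vitl16_384(pretrained=False, use_readout='project')
with torch.no_grad():
    l1, l2, l3, l4 = forward_vit(backbone, torch.randn(1, 3, 384, 384))
print(l1.shape, l2.shape, l3.shape, l4.shape)
# approx. (1, 256, 96, 96) (1, 512, 48, 48) (1, 1024, 24, 24) (1, 1024, 12, 12)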


@ -207,18 +207,19 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
@ -228,10 +229,11 @@ class WanT2V:
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.ones_like(src_video[i], device=device)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):


@ -21,6 +21,30 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def resample(video_fps, video_frames_count, max_frames, target_fps):
import math
video_frame_duration = 1 /video_fps
target_frame_duration = 1 / target_fps
cur_time = 0
target_time = 0
frame_no = 0
frame_ids =[]
while True:
if max_frames != 0 and len(frame_ids) >= max_frames:
break
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
frame_no += add_frames_count
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
if frame_no >= video_frames_count -1:
break
frame_ids = frame_ids[:video_frames_count]
return frame_ids
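
To illustrate resample above: for a 90-frame clip recorded at 30 fps and the 16 fps target used by the Vace pipeline, it keeps roughly every other source frame index.

ids = resample(video_fps=30, video_frames_count=90, max_frames=0, target_fps=16)
print(len(ids), ids[:6])   # ~48 ids starting [0, 2, 4, 6, 8, 10]; the last id is the final frame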
def get_video_frame(file_name, frame_no):
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(file_name)


@ -180,26 +180,17 @@ class VaceVideoProcessor(object):
), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
import math
target_fps = self.max_fps
video_frames_count = len(frame_timestamps)
video_frame_duration = 1 /fps
target_frame_duration = 1 / target_fps
cur_time = 0
target_time = 0
frame_no = 0
frame_ids =[]
for i in range(max_frames):
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
frame_no += add_frames_count
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
if frame_no >= video_frames_count -1:
break
frame_ids = frame_ids[:video_frames_count]
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
from wan.utils.utils import resample
target_fps = self.max_fps
# video_frames_count = len(frame_timestamps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
@ -235,11 +226,11 @@ class VaceVideoProcessor(object):
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
@ -253,23 +244,37 @@ class VaceVideoProcessor(object):
import decord
decord.bridge.set_bridge('torch')
readers = []
src_video = None
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
if torch.is_tensor(data_k):
src_video = data_k
else:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
# # frame_timestamps = frame_timestamps[ :max_frames]
# if trim_video > 0:
# frame_timestamps = frame_timestamps[ :trim_video]
if src_video != None:
fps = 16
length = src_video.shape[1]
if len(readers) > 0:
min_readers = min([len(r) for r in readers])
length = min(length, min_readers )
else:
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
# frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
# frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
if src_video != None:
src_video = src_video[:max_frames]
h, w = src_video.shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
if src_video != None:
videos = [src_video] + videos
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]

wgp.py

@ -141,12 +141,27 @@ def process_prompt_and_add_tasks(state, model_choice):
res = " and ".join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
return
if not "I" in video_prompt_type:
if "I" in video_prompt_type:
if image_refs == None:
gr.Info("You must provide at one Refererence Image")
return
else:
image_refs = None
if not "V" in video_prompt_type:
if "V" in video_prompt_type:
if video_guide == None:
gr.Info("You must provide a Control Video")
return
else:
video_guide = None
if not "M" in video_prompt_type:
if "M" in video_prompt_type:
if video_mask == None:
gr.Info("You must provide a Video Mask ")
return
else:
video_mask = None
if "O" in video_prompt_type and inputs["max_frames"]==0:
gr.Info(f"In order to extend a video, you need to indicate how many frames you want to reuse in the source video.")
return
if isinstance(image_refs, list):
image_refs = [ convert_image(tup[0]) for tup in image_refs ]
@ -260,7 +275,7 @@ def add_video_task(**inputs):
queue = gen["queue"]
task_id += 1
current_task_id = task_id
inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
start_image_data = None
end_image_data = None
for name in inputs_to_query:
@ -718,7 +733,7 @@ if not Path(server_config_filename).is_file():
"transformer_types": [],
"transformer_quantization": "int8",
"text_encoder_filename" : text_encoder_choices[1],
"save_path": os.path.join(os.getcwd(), "gradio_outputs"),
"save_path": "outputs", #os.path.join(os.getcwd(),
"compile" : "",
"metadata_type": "metadata",
"default_ui": "t2v",
@ -726,7 +741,7 @@ if not Path(server_config_filename).is_file():
"clear_file_list" : 0,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM,
"reload_model": 2 }
"preload_model_policy": [] }
with open(server_config_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(server_config))
@ -860,7 +875,7 @@ if len(args.vae_config) > 0:
reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
reload_model = server_config.get("reload_model", 2)
preload_model_policy = server_config.get("preload_model_policy", [])
if args.t2v_14B or args.t2v:
@ -962,8 +977,8 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "", ]
fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0:
@ -1166,7 +1181,7 @@ def load_models(model_filename):
return wan_model, offloadobj, pipe["transformer"]
if reload_model ==3 or reload_model ==4:
if not "P" in preload_model_policy:
wan_model, offloadobj, transformer = None, None, None
reload_needed = True
else:
@ -1254,7 +1269,7 @@ def apply_changes( state,
quantization_choice,
boost_choice = 1,
clear_file_list = 0,
reload_choice = 1,
preload_model_policy_choice = 1,
):
if args.lock_config:
return
@ -1272,7 +1287,7 @@ def apply_changes( state,
"transformer_quantization" : quantization_choice,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
"reload_model" : reload_choice,
"preload_model_policy" : preload_model_policy_choice,
}
if Path(server_config_filename).is_file():
@ -1295,14 +1310,14 @@ def apply_changes( state,
if v != v_old:
changes.append(k)
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
attention_mode = server_config["attention_mode"]
profile = server_config["profile"]
compile = server_config["compile"]
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
boost = server_config["boost"]
reload_model = server_config["reload_model"]
preload_model_policy = server_config["preload_model_policy"]
transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"]
transformer_type = get_model_type(transformer_filename)
@ -1381,7 +1396,8 @@ def abort_generation(state):
gen["abort"] = True
gen["extra_orders"] = 0
wan_model._interrupt= True
if wan_model != None:
wan_model._interrupt= True
msg = "Processing Request to abort Current Generation"
gr.Info(msg)
return msg, gr.Button(interactive= False)
@ -1480,24 +1496,68 @@ def expand_slist(slist, num_inference_steps ):
return new_slist
def convert_image(image):
from PIL import ExifTags, ImageOps
from PIL import ImageOps
from typing import cast
return cast(Image, ImageOps.exif_transpose(image))
# image = image.convert('RGB')
# for orientation in ExifTags.TAGS.keys():
# if ExifTags.TAGS[orientation]=='Orientation':
# break
# exif = image.getexif()
# return image
# if not orientation in exif:
# if exif[orientation] == 3:
# image=image.rotate(180, expand=True)
# elif exif[orientation] == 6:
# image=image.rotate(270, expand=True)
# elif exif[orientation] == 8:
# image=image.rotate(90, expand=True)
# return image
def preprocess_video(process_type, height, width, video_in, max_frames):
from wan.utils.utils import resample
import decord
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(video_in)
fps = reader.get_avg_fps()
frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
frames_list = reader.get_batch(frame_nos)
frame_height, frame_width, _ = frames_list[0].shape
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
# scale = min(height / frame_height, width / frame_width)
new_height = (int(frame_height * scale) // 16) * 16
new_width = (int(frame_width * scale) // 16) * 16
processed_frames_list = []
for frame in frames_list:
frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
processed_frames_list.append(frame)
if process_type=="pose":
from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
cfg_dict = {
"DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
"POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
"RESIZE_SIZE": 1024
}
anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
elif process_type=="depth":
from preprocessing.midas.depth import DepthVideoAnnotator
cfg_dict = {
"PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
}
anno_ins = DepthVideoAnnotator(cfg_dict)
else:
from preprocessing.gray import GrayVideoAnnotator
cfg_dict = {}
anno_ins = GrayVideoAnnotator(cfg_dict)
np_frames = anno_ins.forward(processed_frames_list)
# from preprocessing.dwpose.pose import save_one_video
# save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
torch_frames = []
for np_frame in np_frames:
torch_frame = torch.from_numpy(np_frame)
torch_frames.append(torch_frame)
return torch.stack(torch_frames)
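
A hedged usage sketch for preprocess_video above, e.g. turning a clip into a 16 fps Open Pose control video sized for 480x832 generation; 'myclip.mp4' is a placeholder path and the pose ONNX checkpoints are expected under ckpts/pose/ as listed in download_models.

pose_frames = preprocess_video("pose", height=480, width=832,
                               video_in="myclip.mp4", max_frames=81)
print(pose_frames.shape)   # roughly (frames, H, W, 3) with H and W multiples of 16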
def generate_video(
task_id,
@ -1551,7 +1611,7 @@ def generate_video(
# gr.Info("Unable to generate a Video while a new configuration is being applied.")
# return
if reload_model !=3 and reload_model !=4 :
if "P" in preload_model_policy:
while wan_model == None:
time.sleep(1)
@ -1681,10 +1741,32 @@ def generate_video(
raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
# video_prompt_type = video_prompt_type +"G"
if any(process in video_prompt_type for process in ("P", "D", "G")) :
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, 1, 1)
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, status + " - Extracting Open Pose Information"]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, status + " - Extracting Depth Information"]
preprocess_type = "depth"
elif "G" in video_prompt_type :
progress_args = [0, status + " - Extracting Gray Level Information"]
preprocess_type = "gray"
if preprocess_type != None :
progress(*progress_args )
gen["progress_args"] = progress_args
video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
[video_mask],
[image_refs],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
original_video= "O" in video_prompt_type,
trim_video=max_frames)
else:
src_video, src_mask, src_ref_images = None, None, None
@ -2539,9 +2621,9 @@ def fill_inputs(state):
return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
def preload_model(state):
def preload_model_when_switching(state):
global reload_needed, wan_model, offloadobj
if reload_model == 1:
if "S" in preload_model_policy:
model_filename = state["model_filename"]
if state["model_filename"] != transformer_filename:
wan_model = None
@ -2558,7 +2640,7 @@ def preload_model(state):
def unload_model_if_needed(state):
global reload_needed, wan_model, offloadobj
if reload_model == 4:
if "U" in preload_model_policy:
if wan_model != None:
wan_model = None
if offloadobj is not None:
@ -2567,7 +2649,39 @@ def unload_model_if_needed(state):
gc.collect()
reload_needed= True
def filter_letters(source_str, letters):
ret = ""
for letter in letters:
if letter in source_str:
ret += letter
return ret
def add_to_sequence(source_str, letters):
ret = source_str
for letter in letters:
if not letter in source_str:
ret += letter
return ret
def del_in_sequence(source_str, letters):
ret = source_str
for letter in letters:
if letter in source_str:
ret = ret.replace(letter, "")
return ret
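
video_prompt_type is handled as a small set of single-letter flags ("I" = reference images, "V" = control video, "M" = mask video, "P"/"D"/"C"/"G" = pose / depth / recolor / gray preprocessing, "O" = extend the original video). A short hedged illustration of the three helpers above:

vpt = ""
vpt = add_to_sequence(vpt, "PV")        # user picks "Transfer Human Motion"        -> "PV"
vpt = add_to_sequence(vpt, "I")         # and enables reference images              -> "PVI"
print(filter_letters(vpt, "ODPCMV"))    # "PV": what the Video-to-Video dropdown displays
vpt = del_in_sequence(vpt, "ODPCMV")    # switching the dropdown clears video flags  -> "I"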
def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I")
return video_prompt_type, gr.update(visible = video_prompt_type_image_refs),gr.update(visible = video_prompt_type_image_refs)
def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_video_guide):
video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
global inputs_names #, advanced
@ -2676,19 +2790,36 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
gr.Markdown("<B>Control conditions: Images References (custom Faces or Objects), Video (Open Pose, Depth maps), Mask (inpainting)")
video_prompt_type_value= ui_defaults.get("video_prompt_type","I")
video_prompt_type = gr.Radio( [("Images Ref", "I"),("a Video", "V"), ("Images Refs + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3)
image_refs = gr.Gallery(
label="Images Referencse (Custom faces and Objects to be found in the Video)", type ="pil",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) )
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
video_prompt_type_video_guide = gr.Dropdown(
choices=[
("None, use only the Text Prompt", ""),
("Extend the Control Video", "OV"),
("Transfer Human Motion from the Control Video", "PV"),
("Transfer Depth from the Control Video", "DV"),
("Recolorize the Control Video", "CV"),
("Control Video contains Open Pose, Depth or Black & White ", "V"),
("Inpainting of Control Video using Mask Video ", "MV"),
],
value=filter_letters(video_prompt_type_value, "ODPCMV"),
label="Video to Video"
)
video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use Reference Images (Faces, Objects) to customize the New Video", scale =1 )
video_guide = gr.Video(label= "Reference Video (an animated Video in the Open Pose format or Depth Map video)", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) )
with gr.Row():
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Ref. Video (0 = as many as possible)", visible= "V" in video_prompt_type_value, scale = 2 )
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in video_prompt_type_value, scale =1 )
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpainting, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None) )
image_refs = gr.Gallery( label ="Reference Images",
type ="pil", show_label= True,
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
value= ui_defaults.get("image_refs", None) )
# with gr.Row():
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Reference Images", visible= "I" in video_prompt_type_value, scale =1 )
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpainting, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
advanced_prompt = advanced_ui
@ -2923,7 +3054,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
target_settings = gr.Text(value = "settings", interactive= False, visible= False)
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
# video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
@ -2967,7 +3101,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
).then(fn= fill_inputs,
inputs=[state],
outputs=gen_inputs + extra_inputs
).then(fn= preload_model,
).then(fn= preload_model_when_switching,
inputs=[state],
outputs=[gen_status])
@ -3149,14 +3283,8 @@ def generate_configuration_tab(header, model_choice):
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
reload_choice = gr.Dropdown(
choices=[
("Load Model When Starting the App and Changing Model if Model Changed", 1),
("Load Model When Starting the App and Pressing Generate if Model Changed", 2),
("Load Model When Pressing Generate if Model Changed", 3),
("Load Model When Pressing Generate and Unload Model when Finished", 4),
],
value=server_config.get("reload_model",2),
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
value=server_config.get("preload_model_policy",[]),
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)
@ -3191,7 +3319,7 @@ def generate_configuration_tab(header, model_choice):
quantization_choice,
boost_choice,
clear_file_list_choice,
reload_choice,
preload_model_policy_choice,
],
outputs= [msg , header, model_choice]
)
@ -3201,10 +3329,16 @@ def generate_about_tab():
gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
gr.Markdown("Many thanks to:")
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
def generate_info_tab():
gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
@ -3231,17 +3365,26 @@ def generate_dropdown_model_list():
choices= dropdown_choices,
value= current_model_type,
show_label= False,
scale= 2
scale= 2,
elem_id="model_list",
elem_classes="model_list_class",
)
def create_demo():
css = """
#model_list{
background-color:black;
padding:1px}
#model_list input {
font-size:25px}
.title-with-lines {
display: flex;
align-items: center;
margin: 30px 0;
margin: 25px 0;
}
.line {
flex-grow: 1;
@ -3462,7 +3605,7 @@ def create_demo():
pointer-events: none;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list