Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 14:16:57 +00:00
Simplified Vace, added auto open pose and depth extractors
This commit is contained in:
parent fea835f21f
commit 9ac1674615
17 README.md
@@ -14,7 +14,7 @@
## 🔥 Latest News!!

-* April 4 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
+* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
- A new queuing system that lets you stack in a queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...).
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.

@@ -272,23 +272,24 @@ You can define multiple lines of macros. If there is only one macro line, the ap

### VACE ControlNet introduction

-Vace is a ControlNet 1.3B text2video model that lets you provide visual hints, on top of a text prompt, to guide the generation. It can do more things than image2video, although it is not as good at simply starting a video from an image, because it is only a 1.3B model (in fact 3B) versus 14B and it is not specialized for start frames. However, with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ...
+Vace is a ControlNet 1.3B text2video model that lets you do Video to Video and Reference to Video (inject your own images into the output video). So with Vace you can inject people or objects of your choice into the scene, animate a person, perform inpainting or outpainting, continue a video, ...

First you need to select the Vace 1.3B model in the drop-down box at the top. Please note that, for the moment, Vace works well only with videos up to 5s (81 frames).

Besides the usual Text Prompt, three new types of visual hints can be provided (and combined!):
-- reference Images: use this to inject people or objects in the video. You can select multiple reference Images. The integration of the image is more efficient if the background is replaced by plain white. You can do that with your preferred background remover or use the built-in background remover by checking the box *Remove background*.
-- a Control Video: based on your choice, you can decide to transfer the motion or the depth into a new video. You can tell WanGP to use only the first n frames of the Control Video and to extrapolate the rest. You can also do inpainting: if the video contains areas of grey color 127, they will be considered as masks and will be filled based on the Text Prompt and the reference Images.

+- a Video: this can be a video that contains a body pose (an animated wireframe that indicates the positions of the limbs of a person), a greyed depth-map video, a normal video combined with a masked video (see below), ... The Vace model will detect automatically what to do depending on the video content. You can tell WanGP to use only the first n frames of this Video. All the frames beyond, up to the number of requested frames, will be generated by following the Text Prompt and the other visual hints (for instance reference images). If the video contains areas of grey color 127, they will be considered as masks and will be filled based on the Text Prompt and the reference Images.

+- reference Images: use this to inject people or objects of your choice in the video. You can select multiple reference Images. The integration of the image is more efficient if the background is replaced by plain white. You can do that with your preferred background remover or use the built-in background remover by checking the box *Remove background*.

- a Video Mask
-This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can also do inpainting / outpainting or fill the missing parts of a video more efficiently than with the Video hint alone.
+This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can also do inpainting / outpainting or fill the missing parts of a video more efficiently than with the Video hint alone. Wherever the video mask is white, content will be generated: so with black frames at the beginning and at the end and the rest white, you can generate the missing frames in between.


Examples:
-- Inject people and / or objects into a scene described by a text prompt: Ref. Images + Text Prompt
-- Animate a character described in a text prompt: Body Pose Video + Text Prompt
-- Animate a character of your choice: Ref. Images + Body Pose Video + Text Prompt
+- Inject people and / or objects into a scene described by a text prompt: Ref. Images + Text Prompt
+- Animate a character described in a text prompt: a Video of a person moving + Text Prompt
+- Animate a character of your choice (pose transfer): Ref. Images + a Video of a person moving + Text Prompt
+- Change the style of a scene (depth transfer): a Video that contains objects / persons at different depths + Text Prompt


There are lots of possible combinations. Some of them require preparing some materials (masks on top of a video, full masks, etc.).
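To make the mask conventions above concrete, here is a minimal sketch (not part of this commit; file names, frame count and fps are illustrative) that prepares a Control Video whose missing tail is filled with grey 127 and a matching black / white Video Mask, so Vace generates the remaining frames from the Text Prompt:

```python
# Hedged sketch: build Vace control / mask videos for generating the missing
# tail of a clip. Paths, frame count and fps are placeholders, not WanGP API.
import numpy as np
import imageio

TOTAL_FRAMES = 81   # Vace currently works best with clips up to 81 frames (5 s)
GREY = 127          # grey-127 areas of the Control Video are treated as "to be generated"

reader = imageio.get_reader("input.mp4")
frames = [frame for frame in reader]
reader.close()

h, w = frames[0].shape[:2]
control, mask = [], []
for i in range(TOTAL_FRAMES):
    if i < len(frames):
        control.append(frames[i])                            # existing frame: keep
        mask.append(np.zeros((h, w, 3), dtype=np.uint8))     # black = keep
    else:
        control.append(np.full((h, w, 3), GREY, np.uint8))   # grey 127 = fill from prompt
        mask.append(np.full((h, w, 3), 255, np.uint8))       # white = generate

for path, clip in (("control.mp4", control), ("mask.mp4", mask)):
    writer = imageio.get_writer(path, fps=16, codec="libx264")
    for frame in clip:
        writer.append_data(frame)
    writer.close()
```

The two resulting files would then be supplied as the Control Video and Video Mask hints in the WanGP interface.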
2 preprocessing/dwpose/__init__.py Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
127 preprocessing/dwpose/onnxdet.py Normal file
@@ -0,0 +1,127 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime
|
||||
|
||||
def nms(boxes, scores, nms_thr):
|
||||
"""Single class NMS implemented in Numpy."""
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2]
|
||||
y2 = boxes[:, 3]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= nms_thr)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
||||
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
||||
final_dets = []
|
||||
num_classes = scores.shape[1]
|
||||
for cls_ind in range(num_classes):
|
||||
cls_scores = scores[:, cls_ind]
|
||||
valid_score_mask = cls_scores > score_thr
|
||||
if valid_score_mask.sum() == 0:
|
||||
continue
|
||||
else:
|
||||
valid_scores = cls_scores[valid_score_mask]
|
||||
valid_boxes = boxes[valid_score_mask]
|
||||
keep = nms(valid_boxes, valid_scores, nms_thr)
|
||||
if len(keep) > 0:
|
||||
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
||||
dets = np.concatenate(
|
||||
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
||||
)
|
||||
final_dets.append(dets)
|
||||
if len(final_dets) == 0:
|
||||
return None
|
||||
return np.concatenate(final_dets, 0)
|
||||
|
||||
def demo_postprocess(outputs, img_size, p6=False):
|
||||
grids = []
|
||||
expanded_strides = []
|
||||
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
||||
|
||||
hsizes = [img_size[0] // stride for stride in strides]
|
||||
wsizes = [img_size[1] // stride for stride in strides]
|
||||
|
||||
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
||||
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
||||
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
||||
grids.append(grid)
|
||||
shape = grid.shape[:2]
|
||||
expanded_strides.append(np.full((*shape, 1), stride))
|
||||
|
||||
grids = np.concatenate(grids, 1)
|
||||
expanded_strides = np.concatenate(expanded_strides, 1)
|
||||
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
||||
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
||||
|
||||
return outputs
|
||||
|
||||
def preprocess(img, input_size, swap=(2, 0, 1)):
|
||||
if len(img.shape) == 3:
|
||||
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
||||
else:
|
||||
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
||||
|
||||
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
||||
resized_img = cv2.resize(
|
||||
img,
|
||||
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
).astype(np.uint8)
|
||||
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
||||
|
||||
padded_img = padded_img.transpose(swap)
|
||||
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
||||
return padded_img, r
|
||||
|
||||
def inference_detector(session, oriImg):
|
||||
input_shape = (640,640)
|
||||
img, ratio = preprocess(oriImg, input_shape)
|
||||
|
||||
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
||||
output = session.run(None, ort_inputs)
|
||||
predictions = demo_postprocess(output[0], input_shape)[0]
|
||||
|
||||
boxes = predictions[:, :4]
|
||||
scores = predictions[:, 4:5] * predictions[:, 5:]
|
||||
|
||||
boxes_xyxy = np.ones_like(boxes)
|
||||
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
||||
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
||||
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
||||
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
||||
boxes_xyxy /= ratio
|
||||
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
||||
if dets is not None:
|
||||
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
||||
isscore = final_scores>0.3
|
||||
iscat = final_cls_inds == 0
|
||||
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
||||
final_boxes = final_boxes[isbbox]
|
||||
else:
|
||||
final_boxes = np.array([])
|
||||
|
||||
return final_boxes
|
||||
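For orientation, a minimal usage sketch for the detector above (the ONNX checkpoint and image paths are placeholders; a YOLOX-style person detector such as yolox_l.onnx is assumed):

```python
# Hedged usage sketch for preprocessing/dwpose/onnxdet.py; paths are placeholders.
import cv2
import onnxruntime as ort

from preprocessing.dwpose.onnxdet import inference_detector

session = ort.InferenceSession("ckpts/yolox_l.onnx", providers=["CPUExecutionProvider"])
img = cv2.imread("person.jpg")              # BGR image as read by OpenCV
boxes = inference_detector(session, img)    # person (class 0) boxes as x1, y1, x2, y2
print(boxes)
```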
362 preprocessing/dwpose/onnxpose.py Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
def preprocess(
|
||||
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Do preprocessing for RTMPose model inference.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image in shape.
|
||||
input_size (tuple): Input image size in shape (w, h).
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- resized_img (np.ndarray): Preprocessed image.
|
||||
- center (np.ndarray): Center of image.
|
||||
- scale (np.ndarray): Scale of image.
|
||||
"""
|
||||
# get shape of image
|
||||
img_shape = img.shape[:2]
|
||||
out_img, out_center, out_scale = [], [], []
|
||||
if len(out_bbox) == 0:
|
||||
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
||||
for i in range(len(out_bbox)):
|
||||
x0 = out_bbox[i][0]
|
||||
y0 = out_bbox[i][1]
|
||||
x1 = out_bbox[i][2]
|
||||
y1 = out_bbox[i][3]
|
||||
bbox = np.array([x0, y0, x1, y1])
|
||||
|
||||
# get center and scale
|
||||
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
||||
|
||||
# do affine transformation
|
||||
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
||||
|
||||
# normalize image
|
||||
mean = np.array([123.675, 116.28, 103.53])
|
||||
std = np.array([58.395, 57.12, 57.375])
|
||||
resized_img = (resized_img - mean) / std
|
||||
|
||||
out_img.append(resized_img)
|
||||
out_center.append(center)
|
||||
out_scale.append(scale)
|
||||
|
||||
return out_img, out_center, out_scale
|
||||
|
||||
|
||||
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
||||
"""Inference RTMPose model.
|
||||
|
||||
Args:
|
||||
sess (ort.InferenceSession): ONNXRuntime session.
|
||||
img (np.ndarray): Input image in shape.
|
||||
|
||||
Returns:
|
||||
outputs (np.ndarray): Output of RTMPose model.
|
||||
"""
|
||||
all_out = []
|
||||
# build input
|
||||
for i in range(len(img)):
|
||||
input = [img[i].transpose(2, 0, 1)]
|
||||
|
||||
# build output
|
||||
sess_input = {sess.get_inputs()[0].name: input}
|
||||
sess_output = []
|
||||
for out in sess.get_outputs():
|
||||
sess_output.append(out.name)
|
||||
|
||||
# run model
|
||||
outputs = sess.run(sess_output, sess_input)
|
||||
all_out.append(outputs)
|
||||
|
||||
return all_out
|
||||
|
||||
|
||||
def postprocess(outputs: List[np.ndarray],
|
||||
model_input_size: Tuple[int, int],
|
||||
center: Tuple[int, int],
|
||||
scale: Tuple[int, int],
|
||||
simcc_split_ratio: float = 2.0
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Postprocess for RTMPose model output.
|
||||
|
||||
Args:
|
||||
outputs (np.ndarray): Output of RTMPose model.
|
||||
model_input_size (tuple): RTMPose model Input image size.
|
||||
center (tuple): Center of bbox in shape (x, y).
|
||||
scale (tuple): Scale of bbox in shape (w, h).
|
||||
simcc_split_ratio (float): Split ratio of simcc.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- keypoints (np.ndarray): Rescaled keypoints.
|
||||
- scores (np.ndarray): Model predict scores.
|
||||
"""
|
||||
all_key = []
|
||||
all_score = []
|
||||
for i in range(len(outputs)):
|
||||
# use simcc to decode
|
||||
simcc_x, simcc_y = outputs[i]
|
||||
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
||||
|
||||
# rescale keypoints
|
||||
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
||||
all_key.append(keypoints[0])
|
||||
all_score.append(scores[0])
|
||||
|
||||
return np.array(all_key), np.array(all_score)
|
||||
|
||||
|
||||
def bbox_xyxy2cs(bbox: np.ndarray,
|
||||
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
||||
|
||||
Args:
|
||||
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
||||
as (left, top, right, bottom)
|
||||
padding (float): BBox padding factor that will be multilied to scale.
|
||||
Default: 1.0
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
||||
(n, 2)
|
||||
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
||||
(n, 2)
|
||||
"""
|
||||
# convert single bbox from (4, ) to (1, 4)
|
||||
dim = bbox.ndim
|
||||
if dim == 1:
|
||||
bbox = bbox[None, :]
|
||||
|
||||
# get bbox center and scale
|
||||
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
||||
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
||||
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
||||
|
||||
if dim == 1:
|
||||
center = center[0]
|
||||
scale = scale[0]
|
||||
|
||||
return center, scale
|
||||
|
||||
|
||||
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
||||
aspect_ratio: float) -> np.ndarray:
|
||||
"""Extend the scale to match the given aspect ratio.
|
||||
|
||||
Args:
|
||||
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
||||
aspect_ratio (float): The ratio of ``w/h``
|
||||
|
||||
Returns:
|
||||
np.ndarray: The reshaped image scale in (2, )
|
||||
"""
|
||||
w, h = np.hsplit(bbox_scale, [1])
|
||||
bbox_scale = np.where(w > h * aspect_ratio,
|
||||
np.hstack([w, w / aspect_ratio]),
|
||||
np.hstack([h * aspect_ratio, h]))
|
||||
return bbox_scale
|
||||
|
||||
|
||||
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
||||
"""Rotate a point by an angle.
|
||||
|
||||
Args:
|
||||
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
||||
angle_rad (float): rotation angle in radian
|
||||
|
||||
Returns:
|
||||
np.ndarray: Rotated point in shape (2, )
|
||||
"""
|
||||
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
||||
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
||||
return rot_mat @ pt
|
||||
|
||||
|
||||
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
"""To calculate the affine matrix, three pairs of points are required. This
|
||||
function is used to get the 3rd point, given 2D points a & b.
|
||||
|
||||
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
||||
anticlockwise, using b as the rotation center.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
||||
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
||||
|
||||
Returns:
|
||||
np.ndarray: The 3rd point.
|
||||
"""
|
||||
direction = a - b
|
||||
c = b + np.r_[-direction[1], direction[0]]
|
||||
return c
|
||||
|
||||
|
||||
def get_warp_matrix(center: np.ndarray,
|
||||
scale: np.ndarray,
|
||||
rot: float,
|
||||
output_size: Tuple[int, int],
|
||||
shift: Tuple[float, float] = (0., 0.),
|
||||
inv: bool = False) -> np.ndarray:
|
||||
"""Calculate the affine transformation matrix that can warp the bbox area
|
||||
in the input image to the output size.
|
||||
|
||||
Args:
|
||||
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
||||
scale (np.ndarray[2, ]): Scale of the bounding box
|
||||
wrt [width, height].
|
||||
rot (float): Rotation angle (degree).
|
||||
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
||||
destination heatmaps.
|
||||
shift (0-100%): Shift translation ratio wrt the width/height.
|
||||
Default (0., 0.).
|
||||
inv (bool): Option to inverse the affine transform direction.
|
||||
(inv=False: src->dst or inv=True: dst->src)
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 2x3 transformation matrix
|
||||
"""
|
||||
shift = np.array(shift)
|
||||
src_w = scale[0]
|
||||
dst_w = output_size[0]
|
||||
dst_h = output_size[1]
|
||||
|
||||
# compute transformation matrix
|
||||
rot_rad = np.deg2rad(rot)
|
||||
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
||||
dst_dir = np.array([0., dst_w * -0.5])
|
||||
|
||||
# get four corners of the src rectangle in the original image
|
||||
src = np.zeros((3, 2), dtype=np.float32)
|
||||
src[0, :] = center + scale * shift
|
||||
src[1, :] = center + src_dir + scale * shift
|
||||
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
||||
|
||||
# get four corners of the dst rectangle in the input image
|
||||
dst = np.zeros((3, 2), dtype=np.float32)
|
||||
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
||||
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
||||
|
||||
if inv:
|
||||
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||
else:
|
||||
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||
|
||||
return warp_mat
|
||||
|
||||
|
||||
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
||||
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the bbox image as the model input by affine transform.
|
||||
|
||||
Args:
|
||||
input_size (dict): The input size of the model.
|
||||
bbox_scale (dict): The bbox scale of the img.
|
||||
bbox_center (dict): The bbox center of the img.
|
||||
img (np.ndarray): The original image.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: img after affine transform.
|
||||
- np.ndarray[float32]: bbox scale after affine transform.
|
||||
"""
|
||||
w, h = input_size
|
||||
warp_size = (int(w), int(h))
|
||||
|
||||
# reshape bbox to fixed aspect ratio
|
||||
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
||||
|
||||
# get the affine matrix
|
||||
center = bbox_center
|
||||
scale = bbox_scale
|
||||
rot = 0
|
||||
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
||||
|
||||
# do affine transform
|
||||
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
||||
|
||||
return img, bbox_scale
|
||||
|
||||
|
||||
def get_simcc_maximum(simcc_x: np.ndarray,
|
||||
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get maximum response location and value from simcc representations.
|
||||
|
||||
Note:
|
||||
instance number: N
|
||||
num_keypoints: K
|
||||
heatmap height: H
|
||||
heatmap width: W
|
||||
|
||||
Args:
|
||||
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
||||
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
||||
(K, 2) or (N, K, 2)
|
||||
- vals (np.ndarray): values of maximum heatmap responses in shape
|
||||
(K,) or (N, K)
|
||||
"""
|
||||
N, K, Wx = simcc_x.shape
|
||||
simcc_x = simcc_x.reshape(N * K, -1)
|
||||
simcc_y = simcc_y.reshape(N * K, -1)
|
||||
|
||||
# get maximum value locations
|
||||
x_locs = np.argmax(simcc_x, axis=1)
|
||||
y_locs = np.argmax(simcc_y, axis=1)
|
||||
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
||||
max_val_x = np.amax(simcc_x, axis=1)
|
||||
max_val_y = np.amax(simcc_y, axis=1)
|
||||
|
||||
# get maximum value across x and y axis
|
||||
mask = max_val_x > max_val_y
|
||||
max_val_x[mask] = max_val_y[mask]
|
||||
vals = max_val_x
|
||||
locs[vals <= 0.] = -1
|
||||
|
||||
# reshape
|
||||
locs = locs.reshape(N, K, 2)
|
||||
vals = vals.reshape(N, K)
|
||||
|
||||
return locs, vals
|
||||
|
||||
|
||||
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
||||
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Modulate simcc distribution with Gaussian.
|
||||
|
||||
Args:
|
||||
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
||||
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
||||
simcc_split_ratio (int): The split ratio of simcc.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
||||
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
||||
"""
|
||||
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
||||
keypoints /= simcc_split_ratio
|
||||
|
||||
return keypoints, scores
|
||||
|
||||
|
||||
def inference_pose(session, out_bbox, oriImg):
|
||||
h, w = session.get_inputs()[0].shape[2:]
|
||||
model_input_size = (w, h)
|
||||
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
||||
outputs = inference(session, resized_img)
|
||||
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
||||
|
||||
return keypoints, scores
|
||||
183 preprocessing/dwpose/pose.py Normal file
@@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from . import util
|
||||
from .wholebody import Wholebody, HWC3, resize_image
|
||||
from PIL import Image
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
def convert_to_numpy(image):
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = image.detach().cpu().numpy()
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = image.copy()
|
||||
    else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL.Image are supported.')
|
||||
return image
|
||||
|
||||
|
||||
|
||||
def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
|
||||
bodies = pose['bodies']
|
||||
faces = pose['faces']
|
||||
hands = pose['hands']
|
||||
candidate = bodies['candidate']
|
||||
subset = bodies['subset']
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if use_body:
|
||||
canvas = util.draw_bodypose(canvas, candidate, subset)
|
||||
if use_hand:
|
||||
canvas = util.draw_handpose(canvas, hands)
|
||||
if use_face:
|
||||
canvas = util.draw_facepose(canvas, faces)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
class PoseAnnotator:
|
||||
def __init__(self, cfg, device=None):
|
||||
onnx_det = cfg['DETECTION_MODEL']
|
||||
onnx_pose = cfg['POSE_MODEL']
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
||||
self.pose_estimation = Wholebody(onnx_det, onnx_pose, device=self.device)
|
||||
self.resize_size = cfg.get("RESIZE_SIZE", 1024)
|
||||
self.use_body = cfg.get('USE_BODY', True)
|
||||
self.use_face = cfg.get('USE_FACE', True)
|
||||
self.use_hand = cfg.get('USE_HAND', True)
|
||||
|
||||
@torch.no_grad()
|
||||
    @torch.inference_mode()
|
||||
def forward(self, image):
|
||||
image = convert_to_numpy(image)
|
||||
input_image = HWC3(image[..., ::-1])
|
||||
return self.process(resize_image(input_image, self.resize_size), image.shape[:2])
|
||||
|
||||
def process(self, ori_img, ori_shape):
|
||||
ori_h, ori_w = ori_shape
|
||||
ori_img = ori_img.copy()
|
||||
H, W, C = ori_img.shape
|
||||
with torch.no_grad():
|
||||
candidate, subset, det_result = self.pose_estimation(ori_img)
|
||||
nums, keys, locs = candidate.shape
|
||||
candidate[..., 0] /= float(W)
|
||||
candidate[..., 1] /= float(H)
|
||||
body = candidate[:, :18].copy()
|
||||
body = body.reshape(nums * 18, locs)
|
||||
score = subset[:, :18]
|
||||
for i in range(len(score)):
|
||||
for j in range(len(score[i])):
|
||||
if score[i][j] > 0.3:
|
||||
score[i][j] = int(18 * i + j)
|
||||
else:
|
||||
score[i][j] = -1
|
||||
|
||||
un_visible = subset < 0.3
|
||||
candidate[un_visible] = -1
|
||||
|
||||
foot = candidate[:, 18:24]
|
||||
|
||||
faces = candidate[:, 24:92]
|
||||
|
||||
hands = candidate[:, 92:113]
|
||||
hands = np.vstack([hands, candidate[:, 113:]])
|
||||
|
||||
bodies = dict(candidate=body, subset=score)
|
||||
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
||||
|
||||
ret_data = {}
|
||||
if self.use_body:
|
||||
detected_map_body = draw_pose(pose, H, W, use_body=True)
|
||||
detected_map_body = cv2.resize(detected_map_body[..., ::-1], (ori_w, ori_h),
|
||||
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
||||
ret_data["detected_map_body"] = detected_map_body
|
||||
|
||||
if self.use_face:
|
||||
detected_map_face = draw_pose(pose, H, W, use_face=True)
|
||||
detected_map_face = cv2.resize(detected_map_face[..., ::-1], (ori_w, ori_h),
|
||||
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
||||
ret_data["detected_map_face"] = detected_map_face
|
||||
|
||||
if self.use_body and self.use_face:
|
||||
detected_map_bodyface = draw_pose(pose, H, W, use_body=True, use_face=True)
|
||||
detected_map_bodyface = cv2.resize(detected_map_bodyface[..., ::-1], (ori_w, ori_h),
|
||||
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
||||
ret_data["detected_map_bodyface"] = detected_map_bodyface
|
||||
|
||||
if self.use_hand and self.use_body and self.use_face:
|
||||
detected_map_handbodyface = draw_pose(pose, H, W, use_hand=True, use_body=True, use_face=True)
|
||||
detected_map_handbodyface = cv2.resize(detected_map_handbodyface[..., ::-1], (ori_w, ori_h),
|
||||
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
||||
ret_data["detected_map_handbodyface"] = detected_map_handbodyface
|
||||
|
||||
# convert_size
|
||||
if det_result.shape[0] > 0:
|
||||
w_ratio, h_ratio = ori_w / W, ori_h / H
|
||||
det_result[..., ::2] *= h_ratio
|
||||
det_result[..., 1::2] *= w_ratio
|
||||
det_result = det_result.astype(np.int32)
|
||||
return ret_data, det_result
|
||||
|
||||
|
||||
class PoseBodyFaceAnnotator(PoseAnnotator):
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.use_body, self.use_face, self.use_hand = True, True, False
|
||||
@torch.no_grad()
|
||||
    @torch.inference_mode()
|
||||
def forward(self, image):
|
||||
ret_data, det_result = super().forward(image)
|
||||
return ret_data['detected_map_bodyface']
|
||||
|
||||
|
||||
class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
|
||||
def forward(self, frames):
|
||||
ret_frames = []
|
||||
for frame in frames:
|
||||
anno_frame = super().forward(np.array(frame))
|
||||
ret_frames.append(anno_frame)
|
||||
return ret_frames
|
||||
|
||||
import imageio
|
||||
|
||||
def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
|
||||
try:
|
||||
video_writer = imageio.get_writer(file_path, fps=fps, codec='libx264', quality=quality, macro_block_size=macro_block_size)
|
||||
for frame in videos:
|
||||
video_writer.append_data(frame)
|
||||
video_writer.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Video save error: {e}")
|
||||
return False
|
||||
|
||||
def get_frames(video_path):
|
||||
frames = []
|
||||
|
||||
|
||||
# Opens the Video file with CV2
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
print("video fps: " + str(fps))
|
||||
i = 0
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if ret == False:
|
||||
break
|
||||
frames.append(frame)
|
||||
i += 1
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
return frames, fps
|
||||
|
||||
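Putting the pieces of this file together, a hedged end-to-end sketch (the cfg keys mirror what PoseAnnotator.__init__ reads; the ONNX checkpoint paths and file names are placeholders, not shipped by this commit):

```python
# Hedged sketch: extract an OpenPose-style control video with the classes above.
from preprocessing.dwpose.pose import (PoseBodyFaceVideoAnnotator, get_frames,
                                       save_one_video)

cfg = {
    "DETECTION_MODEL": "ckpts/yolox_l.onnx",       # placeholder path
    "POSE_MODEL": "ckpts/dw-ll_ucoco_384.onnx",    # placeholder path
    "RESIZE_SIZE": 1024,
}

annotator = PoseBodyFaceVideoAnnotator(cfg)
frames, fps = get_frames("input.mp4")        # BGR frames + source frame rate
pose_frames = annotator.forward(frames)      # one body + face pose map per frame
save_one_video("pose.mp4", pose_frames, fps=int(fps))
```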
299 preprocessing/dwpose/util.py Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import cv2
|
||||
|
||||
|
||||
eps = 0.01
|
||||
|
||||
|
||||
def smart_resize(x, s):
|
||||
Ht, Wt = s
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def smart_resize_k(x, fx, fy):
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
Ht, Wt = Ho * fy, Wo * fx
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def padRightDownCorner(img, stride, padValue):
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
pad = 4 * [None]
|
||||
pad[0] = 0 # up
|
||||
pad[1] = 0 # left
|
||||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
||||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
||||
|
||||
img_padded = img
|
||||
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
||||
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
||||
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
||||
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
||||
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
||||
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
||||
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
||||
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
||||
|
||||
return img_padded, pad
|
||||
|
||||
|
||||
def transfer(model, model_weights):
|
||||
transfered_model_weights = {}
|
||||
for weights_name in model.state_dict().keys():
|
||||
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
||||
return transfered_model_weights
|
||||
|
||||
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
H, W, C = canvas.shape
|
||||
candidate = np.array(candidate)
|
||||
subset = np.array(subset)
|
||||
|
||||
stickwidth = 4
|
||||
|
||||
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
||||
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
||||
[1, 16], [16, 18], [3, 17], [6, 18]]
|
||||
|
||||
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
||||
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
||||
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
||||
|
||||
for i in range(17):
|
||||
for n in range(len(subset)):
|
||||
index = subset[n][np.array(limbSeq[i]) - 1]
|
||||
if -1 in index:
|
||||
continue
|
||||
Y = candidate[index.astype(int), 0] * float(W)
|
||||
X = candidate[index.astype(int), 1] * float(H)
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
||||
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
||||
|
||||
canvas = (canvas * 0.6).astype(np.uint8)
|
||||
|
||||
for i in range(18):
|
||||
for n in range(len(subset)):
|
||||
index = int(subset[n][i])
|
||||
if index == -1:
|
||||
continue
|
||||
x, y = candidate[index][0:2]
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas, all_hand_peaks):
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
||||
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
||||
|
||||
for peaks in all_hand_peaks:
|
||||
peaks = np.array(peaks)
|
||||
|
||||
for ie, e in enumerate(edges):
|
||||
x1, y1 = peaks[e[0]]
|
||||
x2, y2 = peaks[e[1]]
|
||||
x1 = int(x1 * W)
|
||||
y1 = int(y1 * H)
|
||||
x2 = int(x2 * W)
|
||||
y2 = int(y2 * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
||||
|
||||
        for i, keypoint in enumerate(peaks):
            x, y = keypoint
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas, all_lmks):
|
||||
H, W, C = canvas.shape
|
||||
for lmks in all_lmks:
|
||||
lmks = np.array(lmks)
|
||||
for lmk in lmks:
|
||||
x, y = lmk
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
# detect hand according to body pose keypoints
|
||||
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
||||
def handDetect(candidate, subset, oriImg):
|
||||
# right hand: wrist 4, elbow 3, shoulder 2
|
||||
# left hand: wrist 7, elbow 6, shoulder 5
|
||||
ratioWristElbow = 0.33
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
for person in subset.astype(int):
|
||||
# if any of three not detected
|
||||
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
||||
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
||||
if not (has_left or has_right):
|
||||
continue
|
||||
hands = []
|
||||
#left hand
|
||||
if has_left:
|
||||
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
||||
x1, y1 = candidate[left_shoulder_index][:2]
|
||||
x2, y2 = candidate[left_elbow_index][:2]
|
||||
x3, y3 = candidate[left_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, True])
|
||||
# right hand
|
||||
if has_right:
|
||||
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
||||
x1, y1 = candidate[right_shoulder_index][:2]
|
||||
x2, y2 = candidate[right_elbow_index][:2]
|
||||
x3, y3 = candidate[right_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, False])
|
||||
|
||||
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
||||
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
||||
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
||||
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
||||
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
||||
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
||||
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
||||
x = x3 + ratioWristElbow * (x3 - x2)
|
||||
y = y3 + ratioWristElbow * (y3 - y2)
|
||||
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
||||
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
||||
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
||||
# x-y refers to the center --> offset to topLeft point
|
||||
# handRectangle.x -= handRectangle.width / 2.f;
|
||||
# handRectangle.y -= handRectangle.height / 2.f;
|
||||
x -= width / 2
|
||||
y -= width / 2 # width = height
|
||||
# overflow the image
|
||||
if x < 0: x = 0
|
||||
if y < 0: y = 0
|
||||
width1 = width
|
||||
width2 = width
|
||||
if x + width > image_width: width1 = image_width - x
|
||||
if y + width > image_height: width2 = image_height - y
|
||||
width = min(width1, width2)
|
||||
            # keep only hand boxes at least 20 pixels wide
|
||||
if width >= 20:
|
||||
detect_result.append([int(x), int(y), int(width), is_left])
|
||||
|
||||
'''
|
||||
return value: [[x, y, w, True if left hand else False]].
|
||||
width=height since the network require squared input.
|
||||
x, y is the coordinate of top left
|
||||
'''
|
||||
return detect_result
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
def faceDetect(candidate, subset, oriImg):
|
||||
# left right eye ear 14 15 16 17
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
for person in subset.astype(int):
|
||||
has_head = person[0] > -1
|
||||
if not has_head:
|
||||
continue
|
||||
|
||||
has_left_eye = person[14] > -1
|
||||
has_right_eye = person[15] > -1
|
||||
has_left_ear = person[16] > -1
|
||||
has_right_ear = person[17] > -1
|
||||
|
||||
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
||||
continue
|
||||
|
||||
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
||||
|
||||
width = 0.0
|
||||
x0, y0 = candidate[head][:2]
|
||||
|
||||
if has_left_eye:
|
||||
x1, y1 = candidate[left_eye][:2]
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if has_right_eye:
|
||||
x1, y1 = candidate[right_eye][:2]
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if has_left_ear:
|
||||
x1, y1 = candidate[left_ear][:2]
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
if has_right_ear:
|
||||
x1, y1 = candidate[right_ear][:2]
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
x, y = x0, y0
|
||||
|
||||
x -= width
|
||||
y -= width
|
||||
|
||||
if x < 0:
|
||||
x = 0
|
||||
|
||||
if y < 0:
|
||||
y = 0
|
||||
|
||||
width1 = width * 2
|
||||
width2 = width * 2
|
||||
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
|
||||
width = min(width1, width2)
|
||||
|
||||
if width >= 20:
|
||||
detect_result.append([int(x), int(y), int(width)])
|
||||
|
||||
return detect_result
|
||||
|
||||
|
||||
# get max index of 2d array
|
||||
def npmax(array):
|
||||
arrayindex = array.argmax(1)
|
||||
arrayvalue = array.max(1)
|
||||
i = arrayvalue.argmax()
|
||||
j = arrayindex[i]
|
||||
return i, j
|
||||
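The drawing helpers above expect coordinates normalized to [0, 1] and a subset array that maps joint slots to rows of candidate (the layout prepared in pose.py); a small hedged sketch:

```python
# Hedged sketch of the candidate / subset convention used by draw_bodypose.
# Random keypoints are used purely to show shapes; they do not form a real pose.
import numpy as np
from preprocessing.dwpose import util

H, W = 512, 512
candidate = np.random.rand(18, 2)               # 18 body joints, normalized (x, y)
subset = np.arange(18, dtype=float)[None, :]    # one person; -1 would mark a missing joint

canvas = np.zeros((H, W, 3), dtype=np.uint8)
canvas = util.draw_bodypose(canvas, candidate, subset)
```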
80 preprocessing/dwpose/wholebody.py Normal file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
x = x[:, :, None]
|
||||
assert x.ndim == 3
|
||||
H, W, C = x.shape
|
||||
assert C == 1 or C == 3 or C == 4
|
||||
if C == 3:
|
||||
return x
|
||||
if C == 1:
|
||||
return np.concatenate([x, x, x], axis=2)
|
||||
if C == 4:
|
||||
color = x[:, :, 0:3].astype(np.float32)
|
||||
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
||||
y = color * alpha + 255.0 * (1.0 - alpha)
|
||||
y = y.clip(0, 255).astype(np.uint8)
|
||||
return y
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
||||
return img
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
|
||||
|
||||
providers = ['CPUExecutionProvider'
|
||||
] if device == 'cpu' else ['CUDAExecutionProvider']
|
||||
# onnx_det = 'annotator/ckpts/yolox_l.onnx'
|
||||
# onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
def __call__(self, ori_img):
|
||||
det_result = inference_detector(self.session_det, ori_img)
|
||||
keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
|
||||
|
||||
keypoints_info = np.concatenate(
|
||||
(keypoints, scores[..., None]), axis=-1)
|
||||
# compute neck joint
|
||||
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
||||
# neck score when visualizing pred
|
||||
neck[:, 2:4] = np.logical_and(
|
||||
keypoints_info[:, 5, 2:4] > 0.3,
|
||||
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
||||
new_keypoints_info = np.insert(
|
||||
keypoints_info, 17, neck, axis=1)
|
||||
mmpose_idx = [
|
||||
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
||||
]
|
||||
openpose_idx = [
|
||||
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
||||
]
|
||||
new_keypoints_info[:, openpose_idx] = \
|
||||
new_keypoints_info[:, mmpose_idx]
|
||||
keypoints_info = new_keypoints_info
|
||||
|
||||
keypoints, scores = keypoints_info[
|
||||
..., :2], keypoints_info[..., 2]
|
||||
|
||||
return keypoints, scores, det_result
|
||||
|
||||
|
||||
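A hedged sketch of getting raw keypoints from the wrapper above (checkpoint and image paths are placeholders):

```python
# Hedged sketch: raw DWPose keypoints, with the COCO-WholeBody order remapped to
# OpenPose body indices as done in Wholebody.__call__ above.
import cv2
from preprocessing.dwpose.wholebody import Wholebody

model = Wholebody("ckpts/yolox_l.onnx", "ckpts/dw-ll_ucoco_384.onnx", device="cpu")
img = cv2.imread("person.jpg")               # BGR, HWC, uint8
keypoints, scores, det_boxes = model(img)    # (persons, 134, 2), (persons, 134), person boxes
print(keypoints.shape, scores.shape, len(det_boxes))
```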
35 preprocessing/gray.py Normal file
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
def convert_to_numpy(image):
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = image.detach().cpu().numpy()
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = image.copy()
|
||||
    else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL.Image are supported.')
|
||||
return image
|
||||
|
||||
class GrayAnnotator:
|
||||
def __init__(self, cfg):
|
||||
pass
|
||||
def forward(self, image):
|
||||
image = convert_to_numpy(image)
|
||||
gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
return gray_map[..., None].repeat(3, axis=2)
|
||||
|
||||
|
||||
class GrayVideoAnnotator(GrayAnnotator):
|
||||
def forward(self, frames):
|
||||
ret_frames = []
|
||||
for frame in frames:
|
||||
anno_frame = super().forward(np.array(frame))
|
||||
ret_frames.append(anno_frame)
|
||||
return ret_frames
|
||||
2 preprocessing/midas/__init__.py Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
166 preprocessing/midas/api.py Normal file
@@ -0,0 +1,166 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# based on https://github.com/isl-org/MiDaS
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from .dpt_depth import DPTDepthModel
|
||||
from .midas_net import MidasNet
|
||||
from .midas_net_custom import MidasNet_small
|
||||
from .transforms import NormalizeImage, PrepareForNet, Resize
|
||||
|
||||
# ISL_PATHS = {
|
||||
# "dpt_large": "dpt_large-midas-2f21e586.pt",
|
||||
# "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
||||
# "midas_v21": "",
|
||||
# "midas_v21_small": "",
|
||||
# }
|
||||
|
||||
# remote_model_path =
|
||||
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
def load_midas_transform(model_type):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load transform only
|
||||
if model_type == 'dpt_large': # DPT-Large
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'minimal'
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'minimal'
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == 'midas_v21':
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'upper_bound'
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
elif model_type == 'midas_v21_small':
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = 'upper_bound'
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
else:
|
||||
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
||||
|
||||
transform = Compose([
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
])
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load network
|
||||
# model_path = ISL_PATHS[model_type]
|
||||
if model_type == 'dpt_large': # DPT-Large
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone='vitl16_384',
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'minimal'
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone='vitb_rn50_384',
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'minimal'
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == 'midas_v21':
|
||||
model = MidasNet(model_path, non_negative=True)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = 'upper_bound'
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
elif model_type == 'midas_v21_small':
|
||||
model = MidasNet_small(model_path,
|
||||
features=64,
|
||||
backbone='efficientnet_lite3',
|
||||
exportable=True,
|
||||
non_negative=True,
|
||||
blocks={'expand': True})
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = 'upper_bound'
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
else:
|
||||
print(
|
||||
f"model_type '{model_type}' not implemented, use: --model_type large"
|
||||
)
|
||||
assert False
|
||||
|
||||
transform = Compose([
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
])
|
||||
|
||||
return model.eval(), transform
|
||||
|
||||
|
||||
class MiDaSInference(nn.Module):
|
||||
MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
|
||||
MODEL_TYPES_ISL = [
|
||||
'dpt_large',
|
||||
'dpt_hybrid',
|
||||
'midas_v21',
|
||||
'midas_v21_small',
|
||||
]
|
||||
|
||||
def __init__(self, model_type, model_path):
|
||||
super().__init__()
|
||||
assert (model_type in self.MODEL_TYPES_ISL)
|
||||
model, _ = load_model(model_type, model_path)
|
||||
self.model = model
|
||||
self.model.train = disabled_train
|
||||
|
||||
def forward(self, x):
|
||||
with torch.no_grad():
|
||||
prediction = self.model(x)
|
||||
return prediction
|
||||
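A hedged sketch of running the depth extractor wrapped above on a single image (the dpt_hybrid checkpoint path is a placeholder; the transform pipeline is the one built by load_midas_transform):

```python
# Hedged sketch: one-image depth inference with MiDaSInference; paths are placeholders.
import cv2
import torch
from preprocessing.midas.api import MiDaSInference, load_midas_transform

model = MiDaSInference("dpt_hybrid", "ckpts/dpt_hybrid-midas-501f0c75.pt")
transform = load_midas_transform("dpt_hybrid")

img = cv2.imread("frame.jpg")[..., ::-1] / 255.0      # BGR -> RGB, float in [0, 1]
sample = transform({"image": img})                    # Resize + NormalizeImage + PrepareForNet
x = torch.from_numpy(sample["image"]).unsqueeze(0)    # (1, 3, H, W) float32
depth = model(x)                                      # relative inverse depth map
print(depth.shape)
```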
18 preprocessing/midas/base_model.py Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def load(self, path):
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
path (str): file path
|
||||
"""
|
||||
parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
|
||||
|
||||
if 'optimizer' in parameters:
|
||||
parameters = parameters['model']
|
||||
|
||||
self.load_state_dict(parameters)
|
||||
391 preprocessing/midas/blocks.py Normal file
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
|
||||
_make_pretrained_vitl16_384)
|
||||
|
||||
|
||||
def _make_encoder(
|
||||
backbone,
|
||||
features,
|
||||
use_pretrained,
|
||||
groups=1,
|
||||
expand=False,
|
||||
exportable=True,
|
||||
hooks=None,
|
||||
use_vit_only=False,
|
||||
use_readout='ignore',
|
||||
):
|
||||
if backbone == 'vitl16_384':
|
||||
pretrained = _make_pretrained_vitl16_384(use_pretrained,
|
||||
hooks=hooks,
|
||||
use_readout=use_readout)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 1024, 1024], features, groups=groups,
|
||||
expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == 'vitb_rn50_384':
|
||||
pretrained = _make_pretrained_vitb_rn50_384(
|
||||
use_pretrained,
|
||||
hooks=hooks,
|
||||
use_vit_only=use_vit_only,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 768, 768], features, groups=groups,
|
||||
expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == 'vitb16_384':
|
||||
pretrained = _make_pretrained_vitb16_384(use_pretrained,
|
||||
hooks=hooks,
|
||||
use_readout=use_readout)
|
||||
scratch = _make_scratch(
|
||||
[96, 192, 384, 768], features, groups=groups,
|
||||
expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
|
||||
elif backbone == 'resnext101_wsl':
|
||||
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
||||
scratch = _make_scratch([256, 512, 1024, 2048],
|
||||
features,
|
||||
groups=groups,
|
||||
expand=expand) # efficientnet_lite3
|
||||
elif backbone == 'efficientnet_lite3':
|
||||
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
|
||||
exportable=exportable)
|
||||
scratch = _make_scratch([32, 48, 136, 384],
|
||||
features,
|
||||
groups=groups,
|
||||
expand=expand) # efficientnet_lite3
|
||||
else:
|
||||
print(f"Backbone '{backbone}' not implemented")
|
||||
assert False
|
||||
|
||||
return pretrained, scratch
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
out_shape4 = out_shape
|
||||
if expand is True:
|
||||
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'tf_efficientnet_lite3',
                                  pretrained=use_pretrained, exportable=exportable)
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2])
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1)

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True)

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn is True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn is True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn is True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand is True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(output, scale_factor=2, mode='bilinear', align_corners=self.align_corners)

        output = self.out_conv(output)

        return output
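For orientation, a minimal sketch (not part of the commit) of how the fusion blocks above are used: the skip input is refined, added to the decoder input, refined again and upsampled by a factor of 2.

import torch
import torch.nn as nn

# assumes FeatureFusionBlock_custom from preprocessing/midas/blocks.py is in scope
block = FeatureFusionBlock_custom(64, nn.ReLU(False), deconv=False, bn=False)
decoder_feat = torch.randn(1, 64, 12, 12)   # coarse decoder feature map
skip_feat = torch.randn(1, 64, 12, 12)      # encoder skip connection at the same resolution
out = block(decoder_feat, skip_feat)
print(out.shape)                            # torch.Size([1, 64, 24, 24])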
preprocessing/midas/depth.py (new file, 84 lines)
@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch
from einops import rearrange
from PIL import Image
import cv2


def convert_to_numpy(image):
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    elif isinstance(image, np.ndarray):
        image = image.copy()
    else:
        raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
    return image


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img, k


def resize_image_ori(h, w, image, k):
    img = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img


class DepthAnnotator:
    def __init__(self, cfg, device=None):
        from .api import MiDaSInference
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
        self.a = cfg.get('A', np.pi * 2.0)
        self.bg_th = cfg.get('BG_TH', 0.1)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        image = convert_to_numpy(image)
        image_depth = image
        h, w, c = image.shape
        image_depth, k = resize_image(image_depth, 1024 if min(h, w) > 1024 else min(h, w))
        image_depth = torch.from_numpy(image_depth).float().to(self.device)
        image_depth = image_depth / 127.5 - 1.0
        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
        depth = self.model(image_depth)[0]

        depth_pt = depth.clone()
        depth_pt -= torch.min(depth_pt)
        depth_pt /= torch.max(depth_pt)
        depth_pt = depth_pt.cpu().numpy()
        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
        depth_image = depth_image[..., None].repeat(3, 2)

        depth_image = resize_image_ori(h, w, depth_image, k)
        return depth_image


class DepthVideoAnnotator(DepthAnnotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            anno_frame = super().forward(np.array(frame))
            ret_frames.append(anno_frame)
        return ret_frames
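A hedged usage sketch for the annotators above; the cfg key matches the code, but the checkpoint path is an assumption, not something this commit pins down.

import numpy as np
from preprocessing.midas.depth import DepthVideoAnnotator

# assumes a MiDaS dpt_hybrid checkpoint has been downloaded to this (hypothetical) location
annotator = DepthVideoAnnotator({'PRETRAINED_MODEL': 'ckpts/depth/dpt_hybrid-midas-501f0c75.pt'})
frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(4)]   # stand-in RGB frames
depth_maps = annotator.forward(frames)   # list of HxWx3 uint8 greyscale depth images, same size as the input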
preprocessing/midas/dpt_depth.py (new file, 107 lines)
@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
from .vit import forward_vit


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone='vitb_rn50_384',
        readout='project',
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            'vitb_rn50_384': [0, 1, 8, 11],
            'vitb16_384': [2, 5, 8, 11],
            'vitl16_384': [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to True if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last is True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs['features'] if 'features' in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
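As a quick sanity check of the wiring above, a minimal sketch (with path=None the weights stay random, so nothing is downloaded beyond what timm needs to build the backbone; whether the vit_base_resnet50_384 name resolves depends on the installed timm version):

import torch
from preprocessing.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone='vitb_rn50_384', non_negative=True).eval()
with torch.no_grad():
    depth = model(torch.randn(1, 3, 384, 384))   # relative inverse depth
print(depth.shape)                               # torch.Size([1, 384, 384])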
preprocessing/midas/midas_net.py (new file, 80 lines)
@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50.
        """
        print('Loading weights: ', path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone='resnext101_wsl', features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
preprocessing/midas/midas_net_custom.py (new file, 167 lines)
@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self,
                 path=None,
                 features=64,
                 backbone='efficientnet_lite3',
                 non_negative=True,
                 exportable=True,
                 channels_last=False,
                 align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50.
        """
        print('Loading weights: ', path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if 'expand' in self.blocks and self.blocks['expand'] is True:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone,
                                                      features,
                                                      use_pretrained,
                                                      groups=self.groups,
                                                      expand=self.expand,
                                                      exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last is True:
            print('self.channels_last = ', self.channels_last)
            x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #     print("FUSED ", previous_name, name)
        #     torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
preprocessing/midas/transforms.py (new file, 231 lines)
@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import cv2
import numpy as np


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample['disparity'].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample['image'] = cv2.resize(sample['image'], tuple(shape[::-1]), interpolation=image_interpolation_method)

    sample['disparity'] = cv2.resize(sample['disparity'], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    sample['mask'] = cv2.resize(
        sample['mask'].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample['mask'] = sample['mask'].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be a multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size.
                    (Output size might be smaller than the given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than the given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == 'lower_bound':
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == 'upper_bound':
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == 'minimal':
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f'resize_method {self.__resize_method} not implemented')

        if self.__resize_method == 'lower_bound':
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == 'upper_bound':
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == 'minimal':
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f'resize_method {self.__resize_method} not implemented')

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample['image'].shape[1], sample['image'].shape[0])

        # resize sample
        sample['image'] = cv2.resize(
            sample['image'],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if 'disparity' in sample:
                sample['disparity'] = cv2.resize(
                    sample['disparity'],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if 'depth' in sample:
                sample['depth'] = cv2.resize(sample['depth'], (width, height), interpolation=cv2.INTER_NEAREST)

            sample['mask'] = cv2.resize(
                sample['mask'].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample['mask'] = sample['mask'].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample['image'] = (sample['image'] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample['image'], (2, 0, 1))
        sample['image'] = np.ascontiguousarray(image).astype(np.float32)

        if 'mask' in sample:
            sample['mask'] = sample['mask'].astype(np.float32)
            sample['mask'] = np.ascontiguousarray(sample['mask'])

        if 'disparity' in sample:
            disparity = sample['disparity'].astype(np.float32)
            sample['disparity'] = np.ascontiguousarray(disparity)

        if 'depth' in sample:
            depth = sample['depth'].astype(np.float32)
            sample['depth'] = np.ascontiguousarray(depth)

        return sample
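The three transforms above are designed to be chained; below is a hedged sketch of a typical MiDaS-style preprocessing pipeline (the 384 network size and the ImageNet-style normalization constants are the usual dpt settings, assumed here rather than taken from this commit).

import cv2
import numpy as np
from torchvision.transforms import Compose

transform = Compose([
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=32,
           resize_method='upper_bound', image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

img = np.random.rand(480, 640, 3).astype(np.float32)   # RGB image already scaled to [0, 1]
sample = transform({'image': img})
print(sample['image'].shape)                            # (3, 288, 384) for this input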
preprocessing/midas/utils.py (new file, 193 lines)
@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Utils for monoDepth."""
import re
import sys

import cv2
import numpy as np
import torch


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, 'rb') as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode('ascii') == 'PF':
            color = True
        elif header.decode('ascii') == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file: ' + path)

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('ascii'))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        scale = float(file.readline().decode('ascii').rstrip())
        if scale < 0:
            # little-endian
            endian = '<'
            scale = -scale
        else:
            # big-endian
            endian = '>'

        data = np.fromfile(file, endian + 'f')
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """
    with open(path, 'wb') as file:
        color = None

        if image.dtype.name != 'float32':
            raise Exception('Image dtype must be float32.')

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (len(image.shape) == 2
              or len(image.shape) == 3 and image.shape[2] == 1):  # greyscale
            color = False
        else:
            raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

        file.write(('PF\n' if color else 'Pf\n').encode())
        file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == '<' or endian == '=' and sys.byteorder == 'little':
            scale = -scale

        file.write('%f\n'.encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float())
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to('cpu')

    depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC)

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    write_pfm(path + '.pfm', depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2**(8 * bits)) - 1

    if depth_max - depth_min > np.finfo('float').eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + '.png', out.astype('uint8'))
    elif bits == 2:
        cv2.imwrite(path + '.png', out.astype('uint16'))

    return
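A small, self-contained round trip through the PFM helpers above (it only needs write access to the working directory):

import numpy as np

depth = np.random.rand(240, 320).astype(np.float32)   # stand-in for a model prediction
write_depth('depth_example', depth, bits=2)           # writes depth_example.pfm and a 16-bit depth_example.png
data, scale = read_pfm('depth_example.pfm')
assert data.shape == (240, 320) and scale == 1.0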
preprocessing/midas/vit.py (new file, 510 lines)
@ -0,0 +1,510 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import types

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = torch.cat((x[:, self.start_index:], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    _ = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations['1']
    layer_2 = pretrained.activations['2']
    layer_3 = pretrained.activations['3']
    layer_4 = pretrained.activations['4']

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size([
                h // pretrained.model.patch_size[1],
                w // pretrained.model.patch_size[0],
            ]),
        ))

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, :self.start_index],
        posemb[0, self.start_index:],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0])

    B = x.shape[0]

    if hasattr(self.patch_embed, 'backbone'):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, 'dist_token', None) is not None:
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == 'ignore':
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == 'add':
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == 'project':
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout='ignore',
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('1'))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('2'))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('3'))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('4'))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[0], kernel_size=1, stride=1, padding=0),
        nn.ConvTranspose2d(in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[1], kernel_size=1, stride=1, padding=0),
        nn.ConvTranspose2d(in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[3], kernel_size=1, stride=1, padding=0),
        nn.Conv2d(in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
    model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
    model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout)


def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
    model = timm.create_model('vit_deit_base_patch16_384', pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout)


def _make_pretrained_deitb16_distil_384(pretrained, use_readout='ignore', hooks=None):
    model = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout='ignore',
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only is True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('1'))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('2'))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation('1'))
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation('2'))

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('3'))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('4'))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only is True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(in_channels=vit_features, out_channels=features[0], kernel_size=1, stride=1, padding=0),
            nn.ConvTranspose2d(in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(in_channels=vit_features, out_channels=features[1], kernel_size=1, stride=1, padding=0),
            nn.ConvTranspose2d(in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
        pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(in_channels=vit_features, out_channels=features[3], kernel_size=1, stride=1, padding=0),
        nn.Conv2d(in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

    return pretrained


def _make_pretrained_vitb_rn50_384(pretrained, use_readout='ignore', hooks=None, use_vit_only=False):
    model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
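For reference, a hedged sketch of what the hybrid backbone above produces (pretrained=False keeps the weights random; whether the vit_base_resnet50_384 name resolves depends on the installed timm version):

import torch

backbone = _make_pretrained_vitb_rn50_384(pretrained=False)
feats = forward_vit(backbone, torch.randn(1, 3, 384, 384))
print([tuple(f.shape) for f in feats])
# approximately strides 4/8/16/32: (1, 256, 96, 96), (1, 512, 48, 48), (1, 768, 24, 24), (1, 768, 12, 12)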
@ -207,18 +207,19 @@ class WanT2V:
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
        image_sizes = []
        for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
            if sub_src_mask is not None and sub_src_video is not None:
                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
                # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
                # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
                src_video[i] = src_video[i].to(device)
                src_mask[i] = src_mask[i].to(device)
                src_video_shape = src_video[i].shape
                if src_video_shape[1] != num_frames:
                    src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                    src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)

                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                image_sizes.append(src_video[i].shape[2:])
            elif sub_src_video is None:
@ -228,10 +229,11 @@ class WanT2V:
            else:
                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
                src_video_shape = src_video[i].shape
                if src_video_shape[1] != num_frames:
                    src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                    src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                image_sizes.append(src_video[i].shape[2:])

        for i, ref_images in enumerate(src_ref_images):

@ -21,6 +21,30 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']

from PIL import Image


def resample(video_fps, video_frames_count, max_frames, target_fps):
    import math

    video_frame_duration = 1 /video_fps
    target_frame_duration = 1 / target_fps

    cur_time = 0
    target_time = 0
    frame_no = 0
    frame_ids =[]
    while True:
        if max_frames != 0 and len(frame_ids) >= max_frames:
            break
        add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
        frame_no += add_frames_count
        frame_ids.append(frame_no)
        cur_time += add_frames_count * video_frame_duration
        target_time += target_frame_duration
        if frame_no >= video_frames_count -1:
            break
    frame_ids = frame_ids[:video_frames_count]
    return frame_ids
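As a quick illustration of the resampling loop above (an illustrative example, not part of the diff): mapping a 30 fps clip onto the 16 fps target keeps roughly every other source frame.

frame_ids = resample(video_fps=30, video_frames_count=81, max_frames=0, target_fps=16)
# -> [0, 2, 4, 6, 8, ...]  (about 16 of every 30 source frames)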
|
||||
def get_video_frame(file_name, frame_no):
|
||||
decord.bridge.set_bridge('torch')
|
||||
reader = decord.VideoReader(file_name)
|
||||
|
||||
@ -180,26 +180,17 @@ class VaceVideoProcessor(object):
|
||||
), axis=1).tolist()
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
|
||||
import math
|
||||
target_fps = self.max_fps
|
||||
video_frames_count = len(frame_timestamps)
|
||||
video_frame_duration = 1 /fps
|
||||
target_frame_duration = 1 / target_fps
|
||||
|
||||
cur_time = 0
|
||||
target_time = 0
|
||||
frame_no = 0
|
||||
frame_ids =[]
|
||||
for i in range(max_frames):
|
||||
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
|
||||
frame_no += add_frames_count
|
||||
frame_ids.append(frame_no)
|
||||
cur_time += add_frames_count * video_frame_duration
|
||||
target_time += target_frame_duration
|
||||
if frame_no >= video_frames_count -1:
|
||||
break
|
||||
frame_ids = frame_ids[:video_frames_count]
|
||||
|
||||
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
|
||||
from wan.utils.utils import resample
|
||||
|
||||
target_fps = self.max_fps
|
||||
|
||||
# video_frames_count = len(frame_timestamps)
|
||||
|
||||
frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
|
||||
|
||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||
h, w = y2 - y1, x2 - x1
|
||||
ratio = h / w
|
||||
@ -235,11 +226,11 @@ class VaceVideoProcessor(object):
|
||||
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
|
||||
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
|
||||
if self.keep_last:
|
||||
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
|
||||
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
|
||||
else:
|
||||
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
|
||||
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
|
||||
|
||||
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
||||
@ -253,23 +244,37 @@ class VaceVideoProcessor(object):
|
||||
import decord
|
||||
decord.bridge.set_bridge('torch')
|
||||
readers = []
|
||||
src_video = None
|
||||
for data_k in data_key_batch:
|
||||
reader = decord.VideoReader(data_k)
|
||||
readers.append(reader)
|
||||
if torch.is_tensor(data_k):
|
||||
src_video = data_k
|
||||
else:
|
||||
reader = decord.VideoReader(data_k)
|
||||
readers.append(reader)
|
||||
|
||||
fps = readers[0].get_avg_fps()
|
||||
length = min([len(r) for r in readers])
|
||||
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||
# # frame_timestamps = frame_timestamps[ :max_frames]
|
||||
# if trim_video > 0:
|
||||
# frame_timestamps = frame_timestamps[ :trim_video]
|
||||
if src_video != None:
|
||||
fps = 16
|
||||
length = src_video.shape[1]
|
||||
if len(readers) > 0:
|
||||
min_readers = min([len(r) for r in readers])
|
||||
length = min(length, min_readers )
|
||||
else:
|
||||
fps = readers[0].get_avg_fps()
|
||||
length = min([len(r) for r in readers])
|
||||
# frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||
# frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
|
||||
h, w = readers[0].next().shape[:2]
|
||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
|
||||
if src_video != None:
|
||||
src_video = src_video[:max_frames]
|
||||
h, w = src_video.shape[1:3]
|
||||
else:
|
||||
h, w = readers[0].next().shape[:2]
|
||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
|
||||
|
||||
# preprocess video
|
||||
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
||||
if src_video != None:
|
||||
videos = [src_video] + videos
|
||||
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||
return *videos, frame_ids, (oh, ow), fps
|
||||
# return videos if len(videos) > 1 else videos[0]
|
||||
|
||||
263
wgp.py
263
wgp.py
@@ -141,12 +141,27 @@ def process_prompt_and_add_tasks(state, model_choice):
            res = VACE_SIZE_CONFIGS.keys().join(" and ")
            gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
            return
        if not "I" in video_prompt_type:
        if "I" in video_prompt_type:
            if image_refs == None:
                gr.Info("You must provide at one Refererence Image")
                return
        else:
            image_refs = None
        if not "V" in video_prompt_type:
        if "V" in video_prompt_type:
            if video_guide == None:
                gr.Info("You must provide a Control Video")
                return
        else:
            video_guide = None
        if not "M" in video_prompt_type:
        if "M" in video_prompt_type:
            if video_mask == None:
                gr.Info("You must provide a Video Mask ")
                return
        else:
            video_mask = None
        if "O" in video_prompt_type and inputs["max_frames"]==0:
            gr.Info(f"In order to extend a video, you need to indicate how many frames you want to reuse in the source video.")
            return

        if isinstance(image_refs, list):
            image_refs = [ convert_image(tup[0]) for tup in image_refs ]
@@ -260,7 +275,7 @@ def add_video_task(**inputs):
    queue = gen["queue"]
    task_id += 1
    current_task_id = task_id
    inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
    inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
    start_image_data = None
    end_image_data = None
    for name in inputs_to_query:
@@ -718,7 +733,7 @@ if not Path(server_config_filename).is_file():
        "transformer_types": [],
        "transformer_quantization": "int8",
        "text_encoder_filename" : text_encoder_choices[1],
        "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
        "save_path": "outputs", #os.path.join(os.getcwd(),
        "compile" : "",
        "metadata_type": "metadata",
        "default_ui": "t2v",
@@ -726,7 +741,7 @@ if not Path(server_config_filename).is_file():
        "clear_file_list" : 0,
        "vae_config": 0,
        "profile" : profile_type.LowRAM_LowVRAM,
        "reload_model": 2 }
        "preload_model_policy": [] }

    with open(server_config_filename, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(server_config))
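For illustration (not from the commit): the new preload_model_policy entry holds single-letter flags instead of the old reload_model integer; the letters match the CheckboxGroup defined further down ("P" preload at launch, "S" preload when switching models, "U" unload once the queue is done). A hypothetical saved value:

server_config["preload_model_policy"] = ["P", "U"]   # example: preload at launch, unload when the queue is done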
@@ -860,7 +875,7 @@ if len(args.vae_config) > 0:
reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
reload_model = server_config.get("reload_model", 2)
preload_model_policy = server_config.get("preload_model_policy", [])


if args.t2v_14B or args.t2v:
@@ -962,8 +977,8 @@ def download_models(transformer_filename, text_encoder_filename):

    from huggingface_hub import hf_hub_download, snapshot_download
    repoId = "DeepBeepMeep/Wan2.1"
    sourceFolderList = ["xlm-roberta-large", "", ]
    fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
    sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
    fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
    targetRoot = "ckpts/"
    for sourceFolder, files in zip(sourceFolderList,fileList ):
        if len(files)==0:
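For illustration (not from the commit): with an empty file list, the new "pose" and "depth" folders are presumably mirrored wholesale from the DeepBeepMeep/Wan2.1 repo (the branch after len(files)==0 lies outside this hunk). A rough hand-rolled equivalent, assuming the checkpoints referenced later by preprocess_video live in those subfolders:

from huggingface_hub import snapshot_download

# Hypothetical: fetch ckpts/pose/yolox_l.onnx, ckpts/pose/dw-ll_ucoco_384.onnx and
# ckpts/depth/dpt_hybrid-midas-501f0c75.pt by mirroring the matching subfolders.
for subfolder in ("pose", "depth"):
    snapshot_download(
        repo_id="DeepBeepMeep/Wan2.1",
        allow_patterns=f"{subfolder}/*",
        local_dir="ckpts",
    )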
@@ -1166,7 +1181,7 @@ def load_models(model_filename):

    return wan_model, offloadobj, pipe["transformer"]

if reload_model ==3 or reload_model ==4:
if not "P" in preload_model_policy:
    wan_model, offloadobj, transformer = None, None, None
    reload_needed = True
else:
@@ -1254,7 +1269,7 @@ def apply_changes( state,
                    quantization_choice,
                    boost_choice = 1,
                    clear_file_list = 0,
                    reload_choice = 1,
                    preload_model_policy_choice = 1,
):
    if args.lock_config:
        return
@@ -1272,7 +1287,7 @@ def apply_changes( state,
        "transformer_quantization" : quantization_choice,
        "boost" : boost_choice,
        "clear_file_list" : clear_file_list,
        "reload_model" : reload_choice,
        "preload_model_policy" : preload_model_policy_choice,
    }

    if Path(server_config_filename).is_file():
@@ -1295,14 +1310,14 @@ def apply_changes( state,
        if v != v_old:
            changes.append(k)

    global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
    global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
    attention_mode = server_config["attention_mode"]
    profile = server_config["profile"]
    compile = server_config["compile"]
    text_encoder_filename = server_config["text_encoder_filename"]
    vae_config = server_config["vae_config"]
    boost = server_config["boost"]
    reload_model = server_config["reload_model"]
    preload_model_policy = server_config["preload_model_policy"]
    transformer_quantization = server_config["transformer_quantization"]
    transformer_types = server_config["transformer_types"]
    transformer_type = get_model_type(transformer_filename)
@@ -1381,7 +1396,8 @@ def abort_generation(state):

        gen["abort"] = True
        gen["extra_orders"] = 0
        wan_model._interrupt= True
        if wan_model != None:
            wan_model._interrupt= True
        msg = "Processing Request to abort Current Generation"
        gr.Info(msg)
        return msg, gr.Button(interactive= False)
@@ -1480,24 +1496,68 @@ def expand_slist(slist, num_inference_steps ):
    return new_slist

def convert_image(image):

    from PIL import ExifTags, ImageOps
    from PIL import ImageOps
    from typing import cast

    return cast(Image, ImageOps.exif_transpose(image))
    # image = image.convert('RGB')
    # for orientation in ExifTags.TAGS.keys():
    #     if ExifTags.TAGS[orientation]=='Orientation':
    #         break
    # exif = image.getexif()
    # return image
    # if not orientation in exif:
    #     if exif[orientation] == 3:
    #         image=image.rotate(180, expand=True)
    #     elif exif[orientation] == 6:
    #         image=image.rotate(270, expand=True)
    #     elif exif[orientation] == 8:
    #         image=image.rotate(90, expand=True)
    # return image


def preprocess_video(process_type, height, width, video_in, max_frames):

    from wan.utils.utils import resample

    import decord
    decord.bridge.set_bridge('torch')
    reader = decord.VideoReader(video_in)

    fps = reader.get_avg_fps()

    frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
    frames_list = reader.get_batch(frame_nos)
    frame_height, frame_width, _ = frames_list[0].shape

    scale = ((height * width ) / (frame_height * frame_width))**(1/2)
    # scale = min(height / frame_height, width / frame_width)

    new_height = (int(frame_height * scale) // 16) * 16
    new_width = (int(frame_width * scale) // 16) * 16

    processed_frames_list = []
    for frame in frames_list:
        frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
        frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
        processed_frames_list.append(frame)

    if process_type=="pose":
        from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
        cfg_dict = {
            "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
            "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
            "RESIZE_SIZE": 1024
        }
        anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
    elif process_type=="depth":
        from preprocessing.midas.depth import DepthVideoAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
        }
        anno_ins = DepthVideoAnnotator(cfg_dict)
    else:
        from preprocessing.gray import GrayVideoAnnotator
        cfg_dict = {}
        anno_ins = GrayVideoAnnotator(cfg_dict)

    np_frames = anno_ins.forward(processed_frames_list)

    # from preprocessing.dwpose.pose import save_one_video
    # save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)

    torch_frames = []
    for np_frame in np_frames:
        torch_frame = torch.from_numpy(np_frame)
        torch_frames.append(torch_frame)

    return torch.stack(torch_frames)

def generate_video(
    task_id,
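For illustration (not from the commit): a minimal call to the new preprocess_video helper, using a hypothetical input file. Frames are resampled to 16 fps, rescaled so the area roughly matches height x width (snapped to multiples of 16), run through the chosen annotator, and returned as a stacked tensor.

# Hypothetical usage: build an Open Pose control video for Vace.
pose_guide = preprocess_video(
    "pose",               # "pose", "depth", anything else -> gray
    height=480,
    width=832,
    video_in="dance.mp4", # assumed input clip
    max_frames=81,
)
print(pose_guide.shape)   # expected: (frames, new_height, new_width, channels)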
@@ -1551,7 +1611,7 @@ def generate_video(
    # gr.Info("Unable to generate a Video while a new configuration is being applied.")
    # return

    if reload_model !=3 and reload_model !=4 :
    if "P" in preload_model_policy:
        while wan_model == None:
            time.sleep(1)

@@ -1681,10 +1741,32 @@ def generate_video(
        raise gr.Error("Teacache not supported for this model")

    if "Vace" in model_filename:
        # video_prompt_type = video_prompt_type +"G"
        if any(process in video_prompt_type for process in ("P", "D", "G")) :
            prompts_max = gen["prompts_max"]

            status = get_generation_status(prompt_no, prompts_max, 1, 1)
            preprocess_type = None
            if "P" in video_prompt_type :
                progress_args = [0, status + " - Extracting Open Pose Information"]
                preprocess_type = "pose"
            elif "D" in video_prompt_type :
                progress_args = [0, status + " - Extracting Depth Information"]
                preprocess_type = "depth"
            elif "G" in video_prompt_type :
                progress_args = [0, status + " - Extracting Gray Level Information"]
                preprocess_type = "gray"

            if preprocess_type != None :
                progress(*progress_args )
                gen["progress_args"] = progress_args
                video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)

        src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
                                                                       [video_mask],
                                                                       [image_refs],
                                                                       video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
                                                                       original_video= "O" in video_prompt_type,
                                                                       trim_video=max_frames)
    else:
        src_video, src_mask, src_ref_images = None, None, None
@@ -2539,9 +2621,9 @@ def fill_inputs(state):

    return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)

def preload_model(state):
def preload_model_when_switching(state):
    global reload_needed, wan_model, offloadobj
    if reload_model == 1:
    if "S" in preload_model_policy:
        model_filename = state["model_filename"]
        if state["model_filename"] != transformer_filename:
            wan_model = None
@@ -2558,7 +2640,7 @@ def preload_model(state):

def unload_model_if_needed(state):
    global reload_needed, wan_model, offloadobj
    if reload_model == 4:
    if "U" in preload_model_policy:
        if wan_model != None:
            wan_model = None
            if offloadobj is not None:
@@ -2567,6 +2649,38 @@ def unload_model_if_needed(state):
            gc.collect()
            reload_needed= True

def filter_letters(source_str, letters):
    ret = ""
    for letter in letters:
        if letter in source_str:
            ret += letter
    return ret

def add_to_sequence(source_str, letters):
    ret = source_str
    for letter in letters:
        if not letter in source_str:
            ret += letter
    return ret

def del_in_sequence(source_str, letters):
    ret = source_str
    for letter in letters:
        if letter in source_str:
            ret = ret.replace(letter, "")
    return ret


def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
    video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I")
    return video_prompt_type, gr.update(visible = video_prompt_type_image_refs),gr.update(visible = video_prompt_type_image_refs)

def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_video_guide):
    video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
    visible = "V" in video_prompt_type
    return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )


def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
    global inputs_names #, advanced
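For illustration (not from the commit): the hidden video_prompt_type string is just a bag of single-letter flags, and the three helpers above compose it. For example:

vpt = "PVI"                                # pose-driven control video + reference images
vpt = del_in_sequence(vpt, "ODPCMV")       # drop every control-video flag -> "I"
vpt = add_to_sequence(vpt, "DV")           # switch to depth transfer     -> "IDV"
filter_letters(vpt, "ODPCMV")              # -> "DV", the value shown in the dropdown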
@@ -2676,19 +2790,36 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))

            with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
                gr.Markdown("<B>Control conditions: Images References (custom Faces or Objects), Video (Open Pose, Depth maps), Mask (inpainting)")
                video_prompt_type_value= ui_defaults.get("video_prompt_type","I")
                video_prompt_type = gr.Radio( [("Images Ref", "I"),("a Video", "V"), ("Images Refs + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3)
                image_refs = gr.Gallery(
                    label="Images Referencse (Custom faces and Objects to be found in the Video)", type ="pil",
                    columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) )
                video_prompt_type_value= ui_defaults.get("video_prompt_type","")
                video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
                video_prompt_type_video_guide = gr.Dropdown(
                    choices=[
                        ("None, use only the Text Prompt", ""),
                        ("Extend the Control Video", "OV"),
                        ("Transfer Human Motion from the Control Video", "PV"),
                        ("Transfer Depth from the Control Video", "DV"),
                        ("Recolorize the Control Video", "CV"),
                        ("Control Video contains Open Pose, Depth or Black & White ", "V"),
                        ("Inpainting of Control Video using Mask Video ", "MV"),
                    ],
                    value=filter_letters(video_prompt_type_value, "ODPCMV"),
                    label="Video to Video"
                )
                video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )

                video_guide = gr.Video(label= "Reference Video (an animated Video in the Open Pose format or Depth Map video)", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) )
                with gr.Row():
                    max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Ref. Video (0 = as many as possible)", visible= "V" in video_prompt_type_value, scale = 2 )
                    remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in video_prompt_type_value, scale =1 )
                video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
                max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )

                video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None) )
                image_refs = gr.Gallery( label ="Reference Images",
                    type ="pil", show_label= True,
                    columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
                    value= ui_defaults.get("image_refs", None) )

                # with gr.Row():
                remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )


                video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))


            advanced_prompt = advanced_ui
@@ -2923,7 +3054,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        target_settings = gr.Text(value = "settings", interactive= False, visible= False)

        image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
        video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
        # video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
        video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])

        show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
            fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
        queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
@@ -2967,7 +3101,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        ).then(fn= fill_inputs,
            inputs=[state],
            outputs=gen_inputs + extra_inputs
        ).then(fn= preload_model,
        ).then(fn= preload_model_when_switching,
            inputs=[state],
            outputs=[gen_status])

@@ -3149,14 +3283,8 @@ def generate_configuration_tab(header, model_choice):
                value=server_config.get("metadata_type", "metadata"),
                label="Metadata Handling"
            )
            reload_choice = gr.Dropdown(
                choices=[
                    ("Load Model When Starting the App and Changing Model if Model Changed", 1),
                    ("Load Model When Starting the App and Pressing Generate if Model Changed", 2),
                    ("Load Model When Pressing Generate if Model Changed", 3),
                    ("Load Model When Pressing Generate and Unload Model when Finished", 4),
                ],
                value=server_config.get("reload_model",2),
            preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
                value=server_config.get("preload_model_policy",[]),
                label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
            )

@@ -3191,7 +3319,7 @@ def generate_configuration_tab(header, model_choice):
                quantization_choice,
                boost_choice,
                clear_file_list_choice,
                reload_choice,
                preload_model_policy_choice,
            ],
            outputs= [msg , header, model_choice]
        )
@@ -3201,10 +3329,16 @@ def generate_about_tab():
    gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
    gr.Markdown("Many thanks to:")
    gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
    gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
    gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
    gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
    gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
    gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
    gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
    gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
    gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
    gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
    gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")


def generate_info_tab():
    gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
@@ -3231,17 +3365,26 @@ def generate_dropdown_model_list():
        choices= dropdown_choices,
        value= current_model_type,
        show_label= False,
        scale= 2
        scale= 2,
        elem_id="model_list",
        elem_classes="model_list_class",
    )


def create_demo():
    css = """
    #model_list{
        background-color:black;
        padding:1px}

    #model_list input {
        font-size:25px}

    .title-with-lines {
        display: flex;
        align-items: center;
        margin: 30px 0;
        margin: 25px 0;
    }
    .line {
        flex-grow: 1;
@@ -3462,7 +3605,7 @@ def create_demo():
        pointer-events: none;
    }
    """
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
    with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
        global model_list
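For illustration (not from the commit): the elem_id / elem_classes added to the model dropdown are what the new #model_list CSS rules hook into; a stripped-down sketch of that pattern:

import gradio as gr

css = """
#model_list { background-color: black; padding: 1px }
#model_list input { font-size: 25px }
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky")) as demo:
    gr.Dropdown(choices=["t2v", "i2v"], value="t2v", show_label=False,
                elem_id="model_list", elem_classes="model_list_class")

# demo.launch()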