Vace powercharged

This commit is contained in:
DeepBeepMeep 2025-06-17 23:45:47 +02:00
parent 826cc3adb7
commit febeb95767
18 changed files with 1945 additions and 605 deletions

View File

@ -20,6 +20,18 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### June 17 2025: WanGP v6.1, Vace Powercharged
Lots of improvements for Vace, the Mother of all Models:
- masks can now be combined with on-the-fly processing of a control video; for instance, you can extract the motion of a specific person defined by a mask
- on-the-fly modification of masks: reversed masks (with the same mask you can modify the background instead of the people covered by the mask), enlarged masks (you can cover a larger area if, for instance, the person you are trying to inject is larger than the one in the mask), and more (a minimal sketch of these mask operations follows below)
- view these modified masks directly inside WanGP during video generation to check that they really are as expected
- multiple frame injections: frames can be injected at any position in the video
- expand past videos in one click: just select a generated video to extend it
Of course, all of these new features work with all Vace finetunes (including Vace FusioniX).
Thanks also to Reevoy24 for adding a notification sound at the end of a generation and for fixing the background color of the current generation summary.
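As a rough illustration of the two mask modifications mentioned above, here is a minimal, hypothetical sketch using NumPy and OpenCV; it shows the idea only and does not reproduce WanGP's actual mask-processing code.
import cv2
import numpy as np

def reverse_mask(mask: np.ndarray) -> np.ndarray:
    # Reversed mask: the background becomes the selected area and vice versa.
    # Assumes a uint8 mask where 255 marks the selected pixels.
    return 255 - mask

def enlarge_mask(mask: np.ndarray, grow_pixels: int = 20) -> np.ndarray:
    # Enlarged mask: dilate so the covered area extends beyond the original subject.
    kernel = np.ones((grow_pixels, grow_pixels), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)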
### June 12 2025: WanGP v6.0
👋 *Finetune models*: Do you find the 20 models supported by WanGP insufficient? Too impatient to wait for the next release to get support for a newly released model? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add support for it yourself by simply creating a finetune model definition. You can then store the model in the cloud (for instance on Huggingface), and the very light finetune definition file can easily be shared with other users. WanGP will automatically download the finetuned model for them.
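To give a feel for what such a finetune definition might contain, here is a minimal, hypothetical sketch written from Python (the real Vace FusioniX definition file appears later in this diff); the file name, model name and URL below are placeholders, not real assets.
import json

finetune_definition = {
    "model": {
        "name": "My Finetune 14B",          # placeholder display name
        "architecture": "vace_14B",         # must match an architecture supported by WanGP
        "description": "Example finetune definition.",
        "URLs": [
            # placeholder URL pointing to the finetuned weights hosted in the cloud
            "https://huggingface.co/your-account/your-model/resolve/main/your_model_fp16.safetensors"
        ]
    }
}

with open("my_finetune.json", "w") as f:
    json.dump(finetune_definition, f, indent=4)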

View File

@ -75,10 +75,10 @@ If you launch the app with the *--save-quantized* switch, WanGP will create a qu
2) Launch WanGP *python wgp.py --save-quantized*
3) In the configuration menu, for the *Transformer Data Type* property, choose either *BF16* or *FP16*
4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
5) To test that this works properly, set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
5) WanGP will automatically update the finetune definition file with the local path of the newly created quantized file (the "URLs" list will get an extra value such as *"ckpts/finetune_quanto_fp16_int8.safetensors"*)
6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
8) In order to share the finetune definition file you will need to store the finetuned model weights in the cloud. You can upload them for instance to *Huggingface*. You can now replace the local path in the definition file with a URL (on Huggingface, click *Copy download link* in the model file's properties to get the URL of the model file)
8) In order to share the finetune definition file you will need to store the finetuned model weights in the cloud. You can upload them for instance to *Huggingface*. You can now replace the local path in the finetune definition file with a URL (on Huggingface, click *Copy download link* in the model file's properties to get the URL of the model file)
You need to create a quantized model specifically for *bf16* or *fp16*, as they cannot be converted on the fly. However, there is no need to do this for a non-quantized model, as it can be converted on the fly while being loaded.
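As a rough sketch of the naming convention involved (mirroring the pattern visible in the finetune URLs, where *..._fp16.safetensors* becomes *..._quanto_fp16_int8.safetensors*), the following hypothetical helper shows how a quantized file name could be derived; it is an approximation, not WanGP's exact logic.
def quantized_filename(model_filename: str) -> str:
    # Sketch only: insert the "quanto_<precision>_int8" marker into the checkpoint name.
    for precision in ("mfp16", "fp16", "mbf16", "bf16"):
        tag = "_" + precision
        if tag in model_filename:
            return model_filename.replace(tag, "_quanto_" + precision + "_int8")
    # Fall back to appending a generic quantization marker before the extension.
    stem, _, ext = model_filename.rpartition(".")
    return stem + "_quanto_int8." + ext

print(quantized_filename("ckpts/Wan14BT2VFusioniX_fp16.safetensors"))
# ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors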

View File

@ -1,17 +1,19 @@
{
"model":
{
"name": "Vace FusioniX 14B",
"architecture" : "vace_14B",
"modules" : ["vace_14B"],
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
],
"auto_quantize": true
},
"model": {
"name": "Vace FusioniX 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
"ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
],
"auto_quantize": true
},
"negative_prompt": "",
"prompt": "",
"resolution": "832x480",

View File

@ -315,7 +315,7 @@ class Inference(object):
@classmethod
def from_pretrained(cls, model_filepath, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs):
def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs):
device = "cuda"
@ -392,8 +392,8 @@ class Inference(object):
# offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
# offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(model, filepath, dtype, None)
from wgp import save_quantized_model
save_quantized_model(model, model_type, filepath, dtype, None)
model.mixed_precision = mixed_precision_transformer

View File

@ -287,13 +287,11 @@ class LTXV:
height, width = input_video.shape[-2:]
else:
if image_start != None:
image_start = image_start[0]
frame_width, frame_height = image_start.size
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
conditioning_media_paths.append(image_start)
conditioning_start_frames.append(0)
if image_end != None:
image_end = image_end[0]
conditioning_media_paths.append(image_end)
conditioning_start_frames.append(frame_num-1)

View File

@ -4,7 +4,7 @@ import math
import numpy as np
import matplotlib
import cv2
matplotlib.use('TkAgg')
# matplotlib.use('TkAgg')
eps = 0.01

View File

@ -354,21 +354,29 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask
foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
output_frames = []
foreground_mat = matting_type == "Foreground"
for frame_origin, frame_alpha in zip(following_frames, alpha):
if foreground_mat:
frame_alpha[frame_alpha > 127] = 255
frame_alpha[frame_alpha <= 127] = 0
else:
if not foreground_mat:
new_alpha = []
for frame_alpha in alpha:
frame_temp = frame_alpha.copy()
frame_alpha[frame_temp > 127] = 0
frame_alpha[frame_temp <= 127] = 255
new_alpha.append(frame_alpha)
alpha = new_alpha
# for frame_origin, frame_alpha in zip(following_frames, alpha):
# if foreground_mat:
# frame_alpha[frame_alpha > 127] = 255
# frame_alpha[frame_alpha <= 127] = 0
# else:
# frame_temp = frame_alpha.copy()
# frame_alpha[frame_temp > 127] = 0
# frame_alpha[frame_temp <= 127] = 255
output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
frame_grey = frame_alpha.copy()
frame_grey[frame_alpha == 255] = 127
output_frame += frame_grey
output_frames.append(output_frame)
foreground = output_frames
# output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
# frame_grey = frame_alpha.copy()
# frame_grey[frame_alpha == 255] = 127
# output_frame += frame_grey
# output_frames.append(output_frame)
foreground = following_frames
if not os.path.exists("mask_outputs"):
os.makedirs("mask_outputs")
@ -465,6 +473,7 @@ def load_unload_models(selected):
global model
global matanyone_model
if selected:
# print("Matanyone Tab Selected")
if model_loaded:
model.samcontroler.sam_controler.model.to(arg_device)
matanyone_model.to(arg_device)
@ -494,6 +503,7 @@ def load_unload_models(selected):
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
model_loaded = True
else:
# print("Matanyone Tab UnSelected")
import gc
model.samcontroler.sam_controler.model.to("cpu")
matanyone_model.to("cpu")
@ -520,7 +530,7 @@ def export_image(image_refs, image_output):
def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
gr.Info("Masked Video Input and Full Mask transferred to Current Video Engine For Inpainting")
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
if "custom_edit" in model_type:
if "custom_edit" in model_type and False:
return gr.update(), alpha_video_output
else:
return foreground_video_output, alpha_video_output

View File

@ -29,7 +29,7 @@ timm
segment-anything
omegaconf
hydra-core
librosa
librosa==0.11.0
loguru
sentencepiece
av

View File

@ -29,6 +29,7 @@ class DTT2V:
checkpoint_dir,
rank=0,
model_filename = None,
model_type = None,
base_model_type = None,
save_quantized = False,
text_encoder_filename = None,
@ -77,8 +78,8 @@ class DTT2V:
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
self.scheduler = FlowUniPCMultistepScheduler()

View File

@ -49,6 +49,7 @@ class WanI2V:
config,
checkpoint_dir,
model_filename = None,
model_type = None,
base_model_type= None,
text_encoder_filename= None,
quantizeTransformer = False,
@ -115,8 +116,8 @@ class WanI2V:
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt

150
wan/modules/motion_patch.py Normal file
View File

@ -0,0 +1,150 @@
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
# Refer to https://github.com/Angtian/VoGE/blob/main/VoGE/Utils.py
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
"""
:param target: [... (can be k or 1), n > M, ...]
:param ind: [... (k), M]
:param dim: dim to apply index on
:return: sel_target [... (k), M, ...]
"""
assert (
len(ind.shape) > dim
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
target = target.expand(
*tuple(
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+ [
-1,
]
* (len(target.shape) - dim)
)
)
ind_pad = ind
if len(target.shape) > dim + 1:
for _ in range(len(target.shape) - (dim + 1)):
ind_pad = ind_pad.unsqueeze(-1)
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
return torch.gather(target, dim=dim, index=ind_pad)
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
"""
:param vert_attr: [n, d] or [b, n, d] color or feature of each vertex
:param weight: [b(optional), w, h, M] weight of selected vertices
:param vert_assign: [b(optional), w, h, M] selective index
:return:
"""
target_dim = len(vert_assign.shape) - 1
if len(vert_attr.shape) == 2:
assert vert_attr.shape[0] > vert_assign.max()
# [n, d] ind: [b(optional), w, h, M]-> [b(optional), w, h, M, d]
sel_attr = ind_sel(
vert_attr[(None,) * target_dim], vert_assign.type(torch.long), dim=target_dim
)
else:
assert vert_attr.shape[1] > vert_assign.max()
sel_attr = ind_sel(
vert_attr[(slice(None),) + (None,)*(target_dim-1)], vert_assign.type(torch.long), dim=target_dim
)
# [b(optional), w, h, M]
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
return final_attr
def patch_motion(
tracks: torch.FloatTensor, # (B, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float = 220.0,
training: bool = True,
tail_dropout: float = 0.2,
vae_divide: tuple = (4, 16),
topk: int = 2,
):
with torch.no_grad():
_, T, H, W = vid.shape
N = tracks.shape[2]
_, tracks, visible = torch.split(
tracks, [1, 2, 1], dim=-1
) # (B, T, N, 2) | (B, T, N, 1)
tracks_n = tracks / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks.device)
tracks_n = tracks_n.clamp(-1, 1)
visible = visible.clamp(0, 1)
if tail_dropout > 0 and training:
TT = visible.shape[1]
rrange = torch.arange(TT, device=visible.device, dtype=visible.dtype)[
None, :, None, None
]
rand_nn = torch.rand_like(visible[:, :1])
rand_rr = torch.rand_like(visible[:, :1]) * (TT - 1)
visible = visible * (
(rand_nn > tail_dropout).type_as(visible)
+ (rrange < rand_rr).type_as(visible)
).clamp(0, 1)
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
tracks.device
)
tracks_pad = tracks[:, 1:]
visible_pad = visible[:, 1:]
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
1
) / (visible_align + 1e-5)
dist_ = (
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
) # T, H, W, N
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
T - 1, 1, 1, N
)
vert_weight, vert_index = torch.topk(
weight, k=min(topk, weight.shape[-1]), dim=-1
)
grid_mode = "bilinear"
point_feature = torch.nn.functional.grid_sample(
vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1],
tracks_n[:, :1].type(vid.dtype),
mode=grid_mode,
padding_mode="zeros",
align_corners=None,
)
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
out_weight = vert_weight.sum(-1) # T - 1, H, W
# out feature -> already soft weighted
mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1))
out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1) # C, T, H, W
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
return torch.cat([out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full], dim=0)

View File

@ -50,6 +50,7 @@ class WanT2V:
checkpoint_dir,
rank=0,
model_filename = None,
model_type = None,
base_model_type = None,
text_encoder_filename = None,
quantizeTransformer = False,
@ -99,9 +100,9 @@ class WanT2V:
# offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file)
# offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
self.model.eval().requires_grad_(False)
if save_quantized:
from wan.utils.utils import save_quantized_model
save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
self.sample_neg_prompt = config.sample_neg_prompt
@ -186,7 +187,25 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None):
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
else:
canvas_height, canvas_width = image_size
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = white_canvas
return ref_img.to(device)
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []):
image_sizes = []
trim_video = len(keep_frames)
canvas_height, canvas_width = image_size
@ -234,25 +253,18 @@ class WanT2V:
src_video[i][:, k:k+1] = 0
src_mask[i][:, k:k+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
src_mask[i][:, k:k+1] = 0
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):

View File

@ -0,0 +1,209 @@
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
from flask import Flask, request, jsonify, render_template
import os
import io
import numpy as np
import torch
import yaml
import matplotlib
import argparse
matplotlib.use('Agg')
app = Flask(__name__, static_folder='static', template_folder='templates')
# ——— Arguments ———————————————————————————————————
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='videos_example')
args = parser.parse_args()
# ——— Configuration —————————————————————————————
BASE_DIR = args.save_dir
STATIC_BASE = os.path.join('static', BASE_DIR)
IMAGES_DIR = os.path.join(STATIC_BASE, 'images')
OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks')
TRACKS_DIR = os.path.join(BASE_DIR, 'tracks')
YAML_PATH = os.path.join(BASE_DIR, 'test.yaml')
IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images')
FIXED_LENGTH = 121
COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
QUANT_MULTI = 8
for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT):
os.makedirs(d, exist_ok=True)
# ——— Helpers ———————————————————————————————————————
def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
# pack into uint16 as before
arr_q = (quant_multi * arr).astype(np.float32)
bio = io.BytesIO()
if compressed:
np.savez_compressed(bio, array=arr_q)
else:
np.savez(bio, array=arr_q)
torch.save(bio.getvalue(), path)
def load_existing_tracks(path):
raw = torch.load(path)
bio = io.BytesIO(raw)
with np.load(bio) as npz:
return npz['array']
# ——— Routes ———————————————————————————————————————
@app.route('/')
def index():
return render_template('index.html')
@app.route('/upload_image', methods=['POST'])
def upload_image():
f = request.files['image']
from PIL import Image
img = Image.open(f.stream)
orig_w, orig_h = img.size
idx = len(os.listdir(IMAGES_DIR)) + 1
ext = f.filename.rsplit('.', 1)[-1]
fname = f"{idx:02d}.{ext}"
img.save(os.path.join(IMAGES_DIR, fname))
img.save(os.path.join(IMAGES_DIR_OUT, fname))
return jsonify({
'image_url': f"{STATIC_BASE}/images/{fname}",
'image_id': idx,
'ext': ext,
'orig_width': orig_w,
'orig_height': orig_h
})
@app.route('/store_tracks', methods=['POST'])
def store_tracks():
data = request.get_json()
image_id = data['image_id']
ext = data['ext']
free_tracks = data.get('tracks', [])
circ_trajs = data.get('circle_trajectories', [])
# Debug lengths
for i, tr in enumerate(free_tracks, 1):
print(f"Freehand Track {i}: {len(tr)} points")
for i, tr in enumerate(circ_trajs, 1):
print(f"Circle/Static Traj {i}: {len(tr)} points")
def pad_pts(tr):
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
n = pts.shape[0]
if n < FIXED_LENGTH:
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
pts = np.vstack((pts, pad))
else:
pts = pts[:FIXED_LENGTH]
return pts.reshape(FIXED_LENGTH, 1, 3)
arrs = []
# 1) Freehand tracks
for i, tr in enumerate(free_tracks):
pts = pad_pts(tr)
arrs.append(pts,)
# 2) Circle + Static combined
for i, tr in enumerate(circ_trajs):
pts = pad_pts(tr)
arrs.append(pts)
print(arrs)
# Nothing to save?
if not arrs:
overlay_file = f"{image_id:02d}.png"
return jsonify({
'status': 'ok',
'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
})
new_tracks = np.stack(arrs, axis=0) # (T_new, FIXED_LENGTH,1,4)
# Load existing .pth and pad old channels to 4 if needed
track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth")
if os.path.exists(track_path):
# shape (T_old, FIXED_LENGTH,1,3) or (...,4)
old = load_existing_tracks(track_path)
if old.ndim == 4 and old.shape[-1] == 3:
pad = np.zeros(
(old.shape[0], old.shape[1], old.shape[2], 1), dtype=np.float32)
old = np.concatenate((old, pad), axis=-1)
all_tracks = np.concatenate([old, new_tracks], axis=0)
else:
all_tracks = new_tracks
# Save updated track file
array_to_npz_bytes(all_tracks, track_path, compressed=True)
# Build overlay PNG
img_path = os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}")
img = plt.imread(img_path)
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(img)
for t in all_tracks:
coords = t[:, 0, :] # (FIXED_LENGTH,4)
ax.plot(coords[:, 0][coords[:, 2] > 0.5], coords[:, 1]
[coords[:, 2] > 0.5], marker='o', color=COLOR_CYCLE[0])
ax.axis('off')
overlay_file = f"{image_id:02d}.png"
fig.savefig(os.path.join(OVERLAY_DIR, overlay_file),
bbox_inches='tight', pad_inches=0)
plt.close(fig)
# Update YAML (unchanged)
entry = {
"image": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/images/{image_id:02d}.{ext}"),
"text": None,
"track": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth")
}
if os.path.exists(YAML_PATH):
with open(YAML_PATH) as yf:
docs = yaml.safe_load(yf) or []
else:
docs = []
for e in docs:
if e.get("image", "").endswith(f"{image_id:02d}.{ext}"):
e.update(entry)
break
else:
docs.append(entry)
with open(YAML_PATH, 'w') as yf:
yaml.dump(docs, yf, default_flow_style=False)
return jsonify({
'status': 'ok',
'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
})
if __name__ == '__main__':
app.run(debug=True)

View File

@ -0,0 +1,571 @@
<!-- Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Track Point Editor</title>
<style>
.btn-row {
display: flex;
align-items: center;
margin: 8px 0;
}
.btn-row > * { margin-right: 12px; }
body { font-family: sans-serif; margin: 16px; }
#topControls, #bottomControls { margin-bottom: 12px; }
button, input, select, label { margin: 4px; }
#canvas { border:1px solid #ccc; display: block; margin: auto; }
#canvas { cursor: crosshair; }
#trajProgress { width: 200px; height: 16px; margin-left:12px; }
</style>
</head>
<body>
<h2>Track Point Editor</h2>
<!-- Top controls -->
<div id="topControls" class="btn-row">
<input type="file" id="fileInput" accept="image/*">
<button id="storeBtn">Store Tracks</button>
</div>
<!-- Main drawing canvas -->
<canvas id="canvas"></canvas>
<!-- Track controls -->
<div id="bottomControls">
<div class="btn-row">
<button id="addTrackBtn">Add Freehand Track</button>
<button id="deleteLastBtn">Delete Last Track</button>
<progress id="trajProgress" max="121" value="0" style="display:none;"></progress>
</div>
<div class="btn-row">
<button id="placeCircleBtn">Place Circle</button>
<button id="addCirclePointBtn">Add Circle Point</button>
<label>Radius:
<input type="range" id="radiusSlider" min="10" max="800" value="50" style="display:none;">
</label>
</div>
<div class="btn-row">
<button id="addStaticBtn">Add Static Point</button>
<label>Static Frames:
<input type="number" id="staticFramesInput" value="121" min="1" style="width:60px">
</label>
</div>
<div class="btn-row">
<select id="trackSelect" style="min-width:160px;"></select>
<div id="colorIndicator"
style="
width:16px;
height:16px;
border:1px solid #444;
display:inline-block;
vertical-align:middle;
margin-left:8px;
pointer-events:none;
visibility:hidden;
">
</div>
<button id="deleteTrackBtn">Delete Selected</button>
<button id="editTrackBtn">Edit Track</button>
<button id="duplicateTrackBtn">Duplicate Track</button>
</div>
<!-- Global motion offset -->
<div class="btn-row">
<label>Motion X (px/frame):
<input type="number" id="motionXInput" value="0" style="width:60px">
</label>
<label>Motion Y (px/frame):
<input type="number" id="motionYInput" value="0" style="width:60px">
</label>
<button id="applySelectedMotionBtn">Add to Selected</button>
<button id="applyAllMotionBtn">Add to All</button>
</div>
</div>
<script>
// ——— DOM refs —————————————————————————————————————————
const canvas = document.getElementById('canvas'),
ctx = canvas.getContext('2d'),
fileIn = document.getElementById('fileInput'),
storeBtn = document.getElementById('storeBtn'),
addTrackBtn = document.getElementById('addTrackBtn'),
deleteLastBtn = document.getElementById('deleteLastBtn'),
placeCircleBtn = document.getElementById('placeCircleBtn'),
addCirclePointBtn = document.getElementById('addCirclePointBtn'),
addStaticBtn = document.getElementById('addStaticBtn'),
staticFramesInput = document.getElementById('staticFramesInput'),
radiusSlider = document.getElementById('radiusSlider'),
trackSelect = document.getElementById('trackSelect'),
deleteTrackBtn = document.getElementById('deleteTrackBtn'),
editTrackBtn = document.getElementById('editTrackBtn'),
duplicateTrackBtn = document.getElementById('duplicateTrackBtn'),
trajProg = document.getElementById('trajProgress'),
colorIndicator = document.getElementById('colorIndicator'),
motionXInput = document.getElementById('motionXInput'),
motionYInput = document.getElementById('motionYInput'),
applySelectedMotionBtn = document.getElementById('applySelectedMotionBtn'),
applyAllMotionBtn = document.getElementById('applyAllMotionBtn');
let img, image_id, ext, origW, origH,
scaleX=1, scaleY=1;
// track data
let free_tracks = [], current_track = [], drawing=false, motionCounter=0;
let circle=null, static_trajs=[];
let mode='', selectedTrack=null, editMode=false, editInfo=null, duplicateBuffer=null;
const COLORS=['red','green','blue','cyan','magenta','yellow','black'],
FIXED_LENGTH=121,
editSigma = 5/Math.sqrt(2*Math.log(2));
// ——— Upload & scale image ————————————————————————————
fileIn.addEventListener('change', async e => {
const f = e.target.files[0]; if (!f) return;
const fd = new FormData(); fd.append('image',f);
const res = await fetch('/upload_image',{method:'POST',body:fd});
const js = await res.json();
image_id=js.image_id; ext=js.ext;
origW=js.orig_width; origH=js.orig_height;
if(origW>=origH){
canvas.width=800; canvas.height=Math.round(origH*800/origW);
} else {
canvas.height=800; canvas.width=Math.round(origW*800/origH);
}
scaleX=origW/canvas.width; scaleY=origH/canvas.height;
img=new Image(); img.src=js.image_url;
img.onload=()=>{
free_tracks=[]; current_track=[];
circle=null; static_trajs=[];
mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
trajProg.style.display='none';
radiusSlider.style.display='none';
trackSelect.innerHTML='';
redraw();
};
});
// ——— Store tracks + depth —————————————————————————
storeBtn.onclick = async () => {
if(!image_id) return alert('Load an image first');
const fh = free_tracks.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
ct = (circle?.trajectories||[]).map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
st = static_trajs.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY})));
const payload = {
image_id, ext,
tracks: fh,
circle_trajectories: ct.concat(st)
};
const res = await fetch('/store_tracks',{
method:'POST',
headers:{'Content-Type':'application/json'},
body: JSON.stringify(payload)
});
const js = await res.json();
img.src=js.overlay_url;
img.onload=()=>ctx.drawImage(img,0,0,canvas.width,canvas.height);
// reset UI
free_tracks=[]; circle=null; static_trajs=[];
mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
trajProg.style.display='none';
radiusSlider.style.display='none';
trackSelect.innerHTML='';
redraw();
};
// ——— Control buttons —————————————————————————————
addTrackBtn.onclick = ()=>{
mode='free'; drawing=true; current_track=[]; motionCounter=0;
trajProg.max=FIXED_LENGTH; trajProg.value=0;
trajProg.style.display='inline-block';
};
deleteLastBtn.onclick = ()=>{
if(drawing){
drawing=false; current_track=[]; trajProg.style.display='none';
} else if(free_tracks.length){
free_tracks.pop(); updateTrackSelect(); redraw();
}
updateColorIndicator();
};
placeCircleBtn.onclick = ()=>{ mode='placeCircle'; drawing=false; };
addCirclePointBtn.onclick = ()=>{ if(!circle) alert('Place circle first'); else mode='addCirclePt'; };
addStaticBtn.onclick = ()=>{ mode='placeStatic'; };
duplicateTrackBtn.onclick = ()=>{
if(!selectedTrack) return alert('Select a track first');
const arr = selectedTrack.type==='free'
? free_tracks[selectedTrack.idx]
: selectedTrack.type==='circle'
? circle.trajectories[selectedTrack.idx]
: static_trajs[selectedTrack.idx];
duplicateBuffer = arr.map(p=>({x:p.x,y:p.y}));
mode='duplicate'; canvas.style.cursor='copy';
};
radiusSlider.oninput = ()=>{
if(!circle) return;
circle.radius = +radiusSlider.value;
circle.trajectories.forEach((traj,i)=>{
const θ = circle.angles[i];
traj.push({
x: circle.cx + Math.cos(θ)*circle.radius,
y: circle.cy + Math.sin(θ)*circle.radius
});
});
if(selectedTrack?.type==='circle')
trajProg.value = circle.trajectories[selectedTrack.idx].length;
redraw();
};
deleteTrackBtn.onclick = ()=>{
if(!selectedTrack) return;
const {type,idx} = selectedTrack;
if(type==='free') free_tracks.splice(idx,1);
else if(type==='circle'){
circle.trajectories.splice(idx,1);
circle.angles.splice(idx,1);
} else {
static_trajs.splice(idx,1);
}
selectedTrack=null;
trajProg.style.display='none';
updateTrackSelect();
redraw();
updateColorIndicator();
};
editTrackBtn.onclick = ()=>{
if(!selectedTrack) return alert('Select a track first');
editMode=!editMode;
editTrackBtn.textContent = editMode?'Stop Editing':'Edit Track';
};
// ——— Track select & depth init —————————————————————
function updateTrackSelect(){
trackSelect.innerHTML='';
free_tracks.forEach((_,i)=>{
const o=document.createElement('option');
o.value=JSON.stringify({type:'free',idx:i});
o.textContent=`Point ${i+1}`;
trackSelect.appendChild(o);
});
if(circle){
circle.trajectories.forEach((_,i)=>{
const o=document.createElement('option');
o.value=JSON.stringify({type:'circle',idx:i});
o.textContent=`CirclePt ${i+1}`;
trackSelect.appendChild(o);
});
}
static_trajs.forEach((_,i)=>{
const o=document.createElement('option');
o.value=JSON.stringify({type:'static',idx:i});
o.textContent=`StaticPt ${i+1}`;
trackSelect.appendChild(o);
});
if(trackSelect.options.length){
trackSelect.selectedIndex=0;
trackSelect.onchange();
}
updateColorIndicator();
}
function applyMotionToTrajectory(traj, dx, dy) {
traj.forEach((pt, frameIdx) => {
pt.x += dx * frameIdx;
pt.y += dy * frameIdx;
});
}
applySelectedMotionBtn.onclick = () => {
if (!selectedTrack) {
return alert('Please select a track first');
}
const dx = parseFloat(motionXInput.value) || 0;
const dy = parseFloat(motionYInput.value) || 0;
// pick the underlying array
let arr = null;
if (selectedTrack.type === 'free') {
arr = free_tracks[selectedTrack.idx];
} else if (selectedTrack.type === 'circle') {
arr = circle.trajectories[selectedTrack.idx];
} else { // 'static'
arr = static_trajs[selectedTrack.idx];
}
applyMotionToTrajectory(arr, dx, dy);
redraw();
};
// 2) Add motion to every track on the canvas
applyAllMotionBtn.onclick = () => {
const dx = parseFloat(motionXInput.value) || 0;
const dy = parseFloat(motionYInput.value) || 0;
// freehand tracks
free_tracks.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
// circle-based tracks
if (circle) {
circle.trajectories.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
}
// static points (now will move over frames)
static_trajs.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
redraw();
};
trackSelect.onchange = ()=>{
if(!trackSelect.value){
selectedTrack=null;
trajProg.style.display='none';
return;
}
selectedTrack = JSON.parse(trackSelect.value);
if(selectedTrack.type==='circle'){
trajProg.style.display='inline-block';
trajProg.max=FIXED_LENGTH;
trajProg.value=circle.trajectories[selectedTrack.idx].length;
} else if(selectedTrack.type==='free'){
trajProg.style.display='inline-block';
trajProg.max=FIXED_LENGTH;
trajProg.value=free_tracks[selectedTrack.idx].length;
} else {
trajProg.style.display='none';
}
updateColorIndicator();
};
// ——— Canvas drawing ————————————————————————————————
canvas.addEventListener('mousedown', e=>{
const r=canvas.getBoundingClientRect(),
x=e.clientX-r.left, y=e.clientY-r.top;
// place circle
if(mode==='placeCircle'){
circle={cx:x,cy:y,radius:50,angles:[],trajectories:[]};
radiusSlider.max=Math.min(canvas.width,canvas.height)|0;
radiusSlider.value=50; radiusSlider.style.display='inline';
mode=''; updateTrackSelect(); redraw(); return;
}
// add circle point
if(mode==='addCirclePt'){
const dx=x-circle.cx, dy=y-circle.cy;
const θ=Math.atan2(dy,dx);
const px=circle.cx+Math.cos(θ)*circle.radius;
const py=circle.cy+Math.sin(θ)*circle.radius;
circle.angles.push(θ);
circle.trajectories.push([{x:px,y:py}]);
mode=''; updateTrackSelect(); redraw(); return;
}
// add static
if (mode === 'placeStatic') {
// how many frames to “hold” the point
const len = parseInt(staticFramesInput.value, 10) || FIXED_LENGTH;
// duplicate the click point len times
const traj = Array.from({ length: len }, () => ({ x, y }));
// push into free_tracks so it's drawn & edited just like any freehand curve
free_tracks.push(traj);
// reset state
mode = '';
updateTrackSelect();
redraw();
return;
}
// duplicate
if(mode==='duplicate' && duplicateBuffer){
const orig = duplicateBuffer;
// click defines translation by first point
const dx = x - orig[0].x, dy = y - orig[0].y;
const newTr = orig.map(p=>({x:p.x+dx, y:p.y+dy}));
free_tracks.push(newTr);
mode=''; duplicateBuffer=null; canvas.style.cursor='crosshair';
updateTrackSelect(); redraw(); return;
}
// editing
if(editMode && selectedTrack){
const arr = selectedTrack.type==='free'
? free_tracks[selectedTrack.idx]
: selectedTrack.type==='circle'
? circle.trajectories[selectedTrack.idx]
: static_trajs[selectedTrack.idx];
let best=0,bd=Infinity;
arr.forEach((p,i)=>{
const d=(p.x-x)**2+(p.y-y)**2;
if(d<bd){ bd=d; best=i; }
});
editInfo={ trackType:selectedTrack.type,
trackIdx:selectedTrack.idx,
ptIdx:best,
startX:x, startY:y };
return;
}
// freehand start
if(mode==='free'){
drawing=true; motionCounter=0;
current_track=[{x,y}];
redraw();
}
});
canvas.addEventListener('mousemove', e=>{
const r=canvas.getBoundingClientRect(),
x=e.clientX-r.left, y=e.clientY-r.top;
// edit mode
if(editMode && editInfo){
const dx=x-editInfo.startX,
dy=y-editInfo.startY;
const {trackType,trackIdx,ptIdx} = editInfo;
const arr = trackType==='free'
? free_tracks[trackIdx]
: trackType==='circle'
? circle.trajectories[trackIdx]
: static_trajs[trackIdx];
arr.forEach((p,i)=>{
const d=i-ptIdx;
const w=Math.exp(-0.5*(d*d)/(editSigma*editSigma));
p.x+=dx*w; p.y+=dy*w;
});
editInfo.startX=x; editInfo.startY=y;
if(selectedTrack?.type==='circle')
trajProg.value=circle.trajectories[selectedTrack.idx].length;
redraw(); return;
}
// freehand draw
if(drawing && (e.buttons&1)){
motionCounter++;
if(motionCounter%2===0){
current_track.push({x,y});
trajProg.value = Math.min(current_track.length, trajProg.max);
redraw();
}
}
});
canvas.addEventListener('mouseup', ()=>{
if(editMode && editInfo){ editInfo=null; return; }
if(drawing){
free_tracks.push(current_track.slice());
drawing=false; current_track=[];
updateTrackSelect(); redraw();
}
});
function updateColorIndicator() {
const idx = trackSelect.selectedIndex;
if (idx < 0) {
colorIndicator.style.visibility = 'hidden';
return;
}
// Pick the color by index
const col = COLORS[idx % COLORS.length];
colorIndicator.style.backgroundColor = col;
colorIndicator.style.visibility = 'visible';
}
// ——— redraw ———
function redraw(){
ctx.clearRect(0, 0, canvas.width, canvas.height);
if (img.complete) ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
// set a fatter line for all strokes
ctx.lineWidth = 2;
// — freehand (and static-turned-freehand) tracks —
free_tracks.forEach((tr, i) => {
const col = COLORS[i % COLORS.length];
ctx.strokeStyle = col;
ctx.fillStyle = col;
if (tr.length === 0) return;
// check if every point equals the first
const allSame = tr.every(p => p.x === tr[0].x && p.y === tr[0].y);
if (allSame) {
// draw a filled circle for a “static” dot
ctx.beginPath();
ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
ctx.fill();
} else {
// normal polyline
ctx.beginPath();
tr.forEach((p, j) =>
j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
);
ctx.stroke();
}
});
if(drawing && current_track.length){
ctx.strokeStyle='black';
ctx.beginPath();
current_track.forEach((p,j)=>
j? ctx.lineTo(p.x,p.y): ctx.moveTo(p.x,p.y));
ctx.stroke();
}
// — circle trajectories —
if (circle) {
// circle outline
ctx.strokeStyle = 'white';
ctx.lineWidth = 1;
ctx.beginPath();
ctx.arc(circle.cx, circle.cy, circle.radius, 0, 2 * Math.PI);
ctx.stroke();
circle.trajectories.forEach((tr, i) => {
const col = COLORS[(free_tracks.length + i) % COLORS.length];
ctx.strokeStyle = col;
ctx.fillStyle = col;
ctx.lineWidth = 2;
if (tr.length <= 1) {
// single-point circle trajectory → dot
ctx.beginPath();
ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
ctx.fill();
} else {
// normal circle track
ctx.beginPath();
tr.forEach((p, j) =>
j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
);
ctx.stroke();
// white handle at last point
const lp = tr[tr.length - 1];
ctx.fillStyle = 'white';
ctx.beginPath();
ctx.arc(lp.x, lp.y, 4, 0, 2 * Math.PI);
ctx.fill();
}
});
}
// — static_trajs (if you still use them separately) —
static_trajs.forEach((tr, i) => {
const p = tr[0];
ctx.fillStyle = 'orange';
ctx.beginPath();
ctx.arc(p.x, p.y, 5, 0, 2 * Math.PI);
ctx.fill();
});
}
</script>
</body>
</html>

74
wan/utils/motion.py Normal file
View File

@ -0,0 +1,74 @@
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os, io
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
def get_tracks_inference(tracks, height, width, quant_multi: Optional[int] = 8, **kwargs):
if isinstance(tracks, str):
tracks = torch.load(tracks)
tracks_np = unzip_to_array(tracks)
tracks = process_tracks(
tracks_np, (width, height), quant_multi=quant_multi, **kwargs
)
return tracks
def unzip_to_array(
data: bytes, key: Union[str, List[str]] = "array"
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
bytes_io = io.BytesIO(data)
if isinstance(key, str):
# Load the NPZ data from the BytesIO object
with np.load(bytes_io) as data:
return data[key]
else:
get = {}
with np.load(bytes_io) as data:
for k in key:
get[k] = data[k]
return get
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], quant_multi: int = 8, **kwargs):
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
# frame_size: tuple (W, H)
tracks = torch.from_numpy(tracks_np).float() / quant_multi
if tracks.shape[1] == 121:
tracks = torch.permute(tracks, (1, 0, 2, 3))
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
short_edge = min(*frame_size)
tracks = tracks - torch.tensor([*frame_size]).type_as(tracks) / 2
tracks = tracks / short_edge * 2
visibles = visibles * 2 - 1
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
out_0 = out_[:1]
out_l = out_[1:] # 121 => 120 | 1
out_l = torch.repeat_interleave(out_l, 2, dim=0)[1::3] # 120 => 240 => 80
return torch.cat([out_0, out_l], dim=0)

View File

@ -78,8 +78,15 @@ def remove_background(img, session=None):
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def save_image(tensor_image, name):
import numpy as np
tensor_image = tensor_image.clone()
tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None:
return height, width
if fit_into_canvas:
scale1 = min(canvas_height / height, canvas_width / width)
scale2 = min(canvas_width / height, canvas_height / width)
@ -337,22 +344,3 @@ def create_progress_hook(filename):
return progress_hook(block_num, block_size, total_size, filename)
return hook
def save_quantized_model(model, model_filename, dtype, config_file):
if "quanto" in model_filename:
return
from mmgp import offload
if dtype == torch.bfloat16:
model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
elif dtype == torch.float16:
model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16")
for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
if "_" + rep in model_filename:
model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
break
if not "quanto" in model_filename:
pos = model_filename.rfind(".")
model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
if not os.path.isfile(model_filename):
offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file)

View File

@ -196,7 +196,7 @@ class VaceVideoProcessor(object):
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= None):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
else:
@ -208,23 +208,23 @@ class VaceVideoProcessor(object):
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs):
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = None, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
src_video = None
src_videos = []
for data_k in data_key_batch:
if torch.is_tensor(data_k):
src_video = data_k
src_videos.append(data_k)
else:
reader = decord.VideoReader(data_k)
readers.append(reader)
if src_video != None:
if len(src_videos) >0:
fps = 16
length = src_video.shape[0] + start_frame
length = src_videos[0].shape[0] + start_frame
if len(readers) > 0:
min_readers = min([len(r) for r in readers])
length = min(length, min_readers )
@ -234,17 +234,17 @@ class VaceVideoProcessor(object):
# frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
# frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
if src_video != None:
src_video = src_video[:max_frames]
h, w = src_video.shape[1:3]
if len(src_videos) >0:
src_videos = [ src_video[:max_frames] for src_video in src_videos]
h, w = src_videos[0].shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame )
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
if src_video != None:
videos = [src_video] + videos
if len(src_videos) >0:
videos = src_videos + videos
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]

1350
wgp.py

File diff suppressed because it is too large Load Diff