mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Vace powercharged
This commit is contained in:
parent
826cc3adb7
commit
febeb95767
12
README.md
12
README.md
@ -20,6 +20,18 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
|
||||
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
|
||||
|
||||
## 🔥 Latest Updates
|
||||
### June 17 2025: WanGP v6.1, Vace Powercharged
|
||||
Lots of improvements for Vace the Mother of all Models:
|
||||
- masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask
|
||||
- on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ...
|
||||
- view these modified masks directly inside WanGP during the video generation to check they are really as expected
|
||||
- multiple frames injections: multiples frames can be injected at any location of the video
|
||||
- expand past videos in on click: just select one generated video to expand it
|
||||
|
||||
Of course all these new stuff work on all Vace finetunes (including Vace Fusionix).
|
||||
|
||||
Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary.
|
||||
|
||||
### June 12 2025: WanGP v6.0
|
||||
👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
|
||||
|
||||
|
||||
@ -75,10 +75,10 @@ If you launch the app with the *--save-quantized* switch, WanGP will create a qu
|
||||
2) Launch WanGP *python wgp.py --save-quantized*
|
||||
3) In the configuration menu *Transformer Data Type* property choose either *BF16* of *FP16*
|
||||
4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
|
||||
5) To test that this works properly set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
|
||||
5) WanGP will update automatically the finetune definition file with the local path of the newly created quantized file (the list "URLs" will have an extra value such as *"ckpts/finetune_quanto_fp16_int8.safetensors"*
|
||||
6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
|
||||
7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
|
||||
8) In order to share the finetune definition file you will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
|
||||
8) In order to share the finetune definition file you will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the finetune definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
|
||||
|
||||
You need to create a quantized model specifically for *bf16* or *fp16* as they can not converted on the fly. However there is no need for a non quantized model as they can be converted on the fly while being loaded.
|
||||
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
{
|
||||
"model":
|
||||
{
|
||||
"model": {
|
||||
"name": "Vace FusioniX 14B",
|
||||
"architecture" : "vace_14B",
|
||||
"modules" : ["vace_14B"],
|
||||
"architecture": "vace_14B",
|
||||
"modules": [
|
||||
"vace_14B"
|
||||
],
|
||||
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
|
||||
"URLs": [
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
|
||||
"ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
|
||||
],
|
||||
"auto_quantize": true
|
||||
},
|
||||
|
||||
@ -315,7 +315,7 @@ class Inference(object):
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_filepath, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs):
|
||||
def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs):
|
||||
|
||||
device = "cuda"
|
||||
|
||||
@ -392,8 +392,8 @@ class Inference(object):
|
||||
# offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
|
||||
# offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
|
||||
if save_quantized:
|
||||
from wan.utils.utils import save_quantized_model
|
||||
save_quantized_model(model, filepath, dtype, None)
|
||||
from wgp import save_quantized_model
|
||||
save_quantized_model(model, model_type, filepath, dtype, None)
|
||||
|
||||
model.mixed_precision = mixed_precision_transformer
|
||||
|
||||
|
||||
@ -287,13 +287,11 @@ class LTXV:
|
||||
height, width = input_video.shape[-2:]
|
||||
else:
|
||||
if image_start != None:
|
||||
image_start = image_start[0]
|
||||
frame_width, frame_height = image_start.size
|
||||
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
|
||||
conditioning_media_paths.append(image_start)
|
||||
conditioning_start_frames.append(0)
|
||||
if image_end != None:
|
||||
image_end = image_end[0]
|
||||
conditioning_media_paths.append(image_end)
|
||||
conditioning_start_frames.append(frame_num-1)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import cv2
|
||||
matplotlib.use('TkAgg')
|
||||
# matplotlib.use('TkAgg')
|
||||
|
||||
eps = 0.01
|
||||
|
||||
|
||||
@ -354,21 +354,29 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask
|
||||
foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
|
||||
output_frames = []
|
||||
foreground_mat = matting_type == "Foreground"
|
||||
for frame_origin, frame_alpha in zip(following_frames, alpha):
|
||||
if foreground_mat:
|
||||
frame_alpha[frame_alpha > 127] = 255
|
||||
frame_alpha[frame_alpha <= 127] = 0
|
||||
else:
|
||||
if not foreground_mat:
|
||||
new_alpha = []
|
||||
for frame_alpha in alpha:
|
||||
frame_temp = frame_alpha.copy()
|
||||
frame_alpha[frame_temp > 127] = 0
|
||||
frame_alpha[frame_temp <= 127] = 255
|
||||
new_alpha.append(frame_alpha)
|
||||
alpha = new_alpha
|
||||
# for frame_origin, frame_alpha in zip(following_frames, alpha):
|
||||
# if foreground_mat:
|
||||
# frame_alpha[frame_alpha > 127] = 255
|
||||
# frame_alpha[frame_alpha <= 127] = 0
|
||||
# else:
|
||||
# frame_temp = frame_alpha.copy()
|
||||
# frame_alpha[frame_temp > 127] = 0
|
||||
# frame_alpha[frame_temp <= 127] = 255
|
||||
|
||||
output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
|
||||
frame_grey = frame_alpha.copy()
|
||||
frame_grey[frame_alpha == 255] = 127
|
||||
output_frame += frame_grey
|
||||
output_frames.append(output_frame)
|
||||
foreground = output_frames
|
||||
# output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
|
||||
# frame_grey = frame_alpha.copy()
|
||||
# frame_grey[frame_alpha == 255] = 127
|
||||
# output_frame += frame_grey
|
||||
# output_frames.append(output_frame)
|
||||
foreground = following_frames
|
||||
|
||||
if not os.path.exists("mask_outputs"):
|
||||
os.makedirs("mask_outputs")
|
||||
@ -465,6 +473,7 @@ def load_unload_models(selected):
|
||||
global model
|
||||
global matanyone_model
|
||||
if selected:
|
||||
# print("Matanyone Tab Selected")
|
||||
if model_loaded:
|
||||
model.samcontroler.sam_controler.model.to(arg_device)
|
||||
matanyone_model.to(arg_device)
|
||||
@ -494,6 +503,7 @@ def load_unload_models(selected):
|
||||
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
||||
model_loaded = True
|
||||
else:
|
||||
# print("Matanyone Tab UnSelected")
|
||||
import gc
|
||||
model.samcontroler.sam_controler.model.to("cpu")
|
||||
matanyone_model.to("cpu")
|
||||
@ -520,7 +530,7 @@ def export_image(image_refs, image_output):
|
||||
def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output):
|
||||
gr.Info("Masked Video Input and Full Mask transferred to Current Video Engine For Inpainting")
|
||||
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
|
||||
if "custom_edit" in model_type:
|
||||
if "custom_edit" in model_type and False:
|
||||
return gr.update(), alpha_video_output
|
||||
else:
|
||||
return foreground_video_output, alpha_video_output
|
||||
|
||||
@ -29,7 +29,7 @@ timm
|
||||
segment-anything
|
||||
omegaconf
|
||||
hydra-core
|
||||
librosa
|
||||
librosa==0.11.0
|
||||
loguru
|
||||
sentencepiece
|
||||
av
|
||||
|
||||
@ -29,6 +29,7 @@ class DTT2V:
|
||||
checkpoint_dir,
|
||||
rank=0,
|
||||
model_filename = None,
|
||||
model_type = None,
|
||||
base_model_type = None,
|
||||
save_quantized = False,
|
||||
text_encoder_filename = None,
|
||||
@ -77,8 +78,8 @@ class DTT2V:
|
||||
|
||||
self.model.eval().requires_grad_(False)
|
||||
if save_quantized:
|
||||
from wan.utils.utils import save_quantized_model
|
||||
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
|
||||
from wgp import save_quantized_model
|
||||
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
|
||||
|
||||
self.scheduler = FlowUniPCMultistepScheduler()
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ class WanI2V:
|
||||
config,
|
||||
checkpoint_dir,
|
||||
model_filename = None,
|
||||
model_type = None,
|
||||
base_model_type= None,
|
||||
text_encoder_filename= None,
|
||||
quantizeTransformer = False,
|
||||
@ -115,8 +116,8 @@ class WanI2V:
|
||||
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
|
||||
self.model.eval().requires_grad_(False)
|
||||
if save_quantized:
|
||||
from wan.utils.utils import save_quantized_model
|
||||
save_quantized_model(self.model, model_filename[0], dtype, base_config_file)
|
||||
from wgp import save_quantized_model
|
||||
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
|
||||
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
150
wan/modules/motion_patch.py
Normal file
150
wan/modules/motion_patch.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
|
||||
# Refer to https://github.com/Angtian/VoGE/blob/main/VoGE/Utils.py
|
||||
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
|
||||
"""
|
||||
:param target: [... (can be k or 1), n > M, ...]
|
||||
:param ind: [... (k), M]
|
||||
:param dim: dim to apply index on
|
||||
:return: sel_target [... (k), M, ...]
|
||||
"""
|
||||
assert (
|
||||
len(ind.shape) > dim
|
||||
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
|
||||
|
||||
target = target.expand(
|
||||
*tuple(
|
||||
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
|
||||
+ [
|
||||
-1,
|
||||
]
|
||||
* (len(target.shape) - dim)
|
||||
)
|
||||
)
|
||||
|
||||
ind_pad = ind
|
||||
|
||||
if len(target.shape) > dim + 1:
|
||||
for _ in range(len(target.shape) - (dim + 1)):
|
||||
ind_pad = ind_pad.unsqueeze(-1)
|
||||
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
|
||||
|
||||
return torch.gather(target, dim=dim, index=ind_pad)
|
||||
|
||||
|
||||
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
|
||||
"""
|
||||
|
||||
:param vert_attr: [n, d] or [b, n, d] color or feature of each vertex
|
||||
:param weight: [b(optional), w, h, M] weight of selected vertices
|
||||
:param vert_assign: [b(optional), w, h, M] selective index
|
||||
:return:
|
||||
"""
|
||||
target_dim = len(vert_assign.shape) - 1
|
||||
if len(vert_attr.shape) == 2:
|
||||
assert vert_attr.shape[0] > vert_assign.max()
|
||||
# [n, d] ind: [b(optional), w, h, M]-> [b(optional), w, h, M, d]
|
||||
sel_attr = ind_sel(
|
||||
vert_attr[(None,) * target_dim], vert_assign.type(torch.long), dim=target_dim
|
||||
)
|
||||
else:
|
||||
assert vert_attr.shape[1] > vert_assign.max()
|
||||
sel_attr = ind_sel(
|
||||
vert_attr[(slice(None),) + (None,)*(target_dim-1)], vert_assign.type(torch.long), dim=target_dim
|
||||
)
|
||||
|
||||
# [b(optional), w, h, M]
|
||||
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
|
||||
return final_attr
|
||||
|
||||
|
||||
def patch_motion(
|
||||
tracks: torch.FloatTensor, # (B, T, N, 4)
|
||||
vid: torch.FloatTensor, # (C, T, H, W)
|
||||
temperature: float = 220.0,
|
||||
training: bool = True,
|
||||
tail_dropout: float = 0.2,
|
||||
vae_divide: tuple = (4, 16),
|
||||
topk: int = 2,
|
||||
):
|
||||
with torch.no_grad():
|
||||
_, T, H, W = vid.shape
|
||||
N = tracks.shape[2]
|
||||
_, tracks, visible = torch.split(
|
||||
tracks, [1, 2, 1], dim=-1
|
||||
) # (B, T, N, 2) | (B, T, N, 1)
|
||||
tracks_n = tracks / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks.device)
|
||||
tracks_n = tracks_n.clamp(-1, 1)
|
||||
visible = visible.clamp(0, 1)
|
||||
|
||||
if tail_dropout > 0 and training:
|
||||
TT = visible.shape[1]
|
||||
rrange = torch.arange(TT, device=visible.device, dtype=visible.dtype)[
|
||||
None, :, None, None
|
||||
]
|
||||
rand_nn = torch.rand_like(visible[:, :1])
|
||||
rand_rr = torch.rand_like(visible[:, :1]) * (TT - 1)
|
||||
visible = visible * (
|
||||
(rand_nn > tail_dropout).type_as(visible)
|
||||
+ (rrange < rand_rr).type_as(visible)
|
||||
).clamp(0, 1)
|
||||
|
||||
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
|
||||
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
|
||||
|
||||
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
|
||||
tracks.device
|
||||
)
|
||||
|
||||
tracks_pad = tracks[:, 1:]
|
||||
visible_pad = visible[:, 1:]
|
||||
|
||||
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
|
||||
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
|
||||
1
|
||||
) / (visible_align + 1e-5)
|
||||
dist_ = (
|
||||
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
|
||||
) # T, H, W, N
|
||||
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
|
||||
T - 1, 1, 1, N
|
||||
)
|
||||
vert_weight, vert_index = torch.topk(
|
||||
weight, k=min(topk, weight.shape[-1]), dim=-1
|
||||
)
|
||||
|
||||
grid_mode = "bilinear"
|
||||
point_feature = torch.nn.functional.grid_sample(
|
||||
vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1],
|
||||
tracks_n[:, :1].type(vid.dtype),
|
||||
mode=grid_mode,
|
||||
padding_mode="zeros",
|
||||
align_corners=None,
|
||||
)
|
||||
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
|
||||
|
||||
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
|
||||
out_weight = vert_weight.sum(-1) # T - 1, H, W
|
||||
|
||||
# out feature -> already soft weighted
|
||||
mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1))
|
||||
|
||||
out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1) # C, T, H, W
|
||||
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
|
||||
return torch.cat([out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full], dim=0)
|
||||
@ -50,6 +50,7 @@ class WanT2V:
|
||||
checkpoint_dir,
|
||||
rank=0,
|
||||
model_filename = None,
|
||||
model_type = None,
|
||||
base_model_type = None,
|
||||
text_encoder_filename = None,
|
||||
quantizeTransformer = False,
|
||||
@ -100,8 +101,8 @@ class WanT2V:
|
||||
# offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
|
||||
self.model.eval().requires_grad_(False)
|
||||
if save_quantized:
|
||||
from wan.utils.utils import save_quantized_model
|
||||
save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
|
||||
from wgp import save_quantized_model
|
||||
save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
@ -186,7 +187,25 @@ class WanT2V:
|
||||
def vace_latent(self, z, m):
|
||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None):
|
||||
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device):
|
||||
ref_width, ref_height = ref_img.size
|
||||
if (ref_height, ref_width) == image_size:
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
else:
|
||||
canvas_height, canvas_width = image_size
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
|
||||
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
|
||||
ref_img = white_canvas
|
||||
return ref_img.to(device)
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []):
|
||||
image_sizes = []
|
||||
trim_video = len(keep_frames)
|
||||
canvas_height, canvas_width = image_size
|
||||
@ -234,25 +253,18 @@ class WanT2V:
|
||||
src_video[i][:, k:k+1] = 0
|
||||
src_mask[i][:, k:k+1] = 1
|
||||
|
||||
for k, frame in enumerate(inject_frames):
|
||||
if frame != None:
|
||||
src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device)
|
||||
src_mask[i][:, k:k+1] = 0
|
||||
|
||||
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
image_size = image_sizes[i]
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None:
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
if ref_img.shape[-2:] != image_size:
|
||||
canvas_height, canvas_width = image_size
|
||||
ref_height, ref_width = ref_img.shape[-2:]
|
||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||
ref_img = white_canvas
|
||||
src_ref_images[i][j] = ref_img.to(device)
|
||||
src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
|
||||
return src_video, src_mask, src_ref_images
|
||||
|
||||
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
|
||||
|
||||
209
wan/trajectory_editor/app.py
Normal file
209
wan/trajectory_editor/app.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from flask import Flask, request, jsonify, render_template
|
||||
import os
|
||||
import io
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
import matplotlib
|
||||
import argparse
|
||||
matplotlib.use('Agg')
|
||||
|
||||
app = Flask(__name__, static_folder='static', template_folder='templates')
|
||||
|
||||
|
||||
# ——— Arguments ———————————————————————————————————
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--save_dir', type=str, default='videos_example')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# ——— Configuration —————————————————————————————
|
||||
BASE_DIR = args.save_dir
|
||||
STATIC_BASE = os.path.join('static', BASE_DIR)
|
||||
IMAGES_DIR = os.path.join(STATIC_BASE, 'images')
|
||||
OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks')
|
||||
TRACKS_DIR = os.path.join(BASE_DIR, 'tracks')
|
||||
YAML_PATH = os.path.join(BASE_DIR, 'test.yaml')
|
||||
IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images')
|
||||
|
||||
FIXED_LENGTH = 121
|
||||
COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
|
||||
QUANT_MULTI = 8
|
||||
|
||||
for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
# ——— Helpers ———————————————————————————————————————
|
||||
|
||||
|
||||
def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
|
||||
# pack into uint16 as before
|
||||
arr_q = (quant_multi * arr).astype(np.float32)
|
||||
bio = io.BytesIO()
|
||||
if compressed:
|
||||
np.savez_compressed(bio, array=arr_q)
|
||||
else:
|
||||
np.savez(bio, array=arr_q)
|
||||
torch.save(bio.getvalue(), path)
|
||||
|
||||
|
||||
def load_existing_tracks(path):
|
||||
raw = torch.load(path)
|
||||
bio = io.BytesIO(raw)
|
||||
with np.load(bio) as npz:
|
||||
return npz['array']
|
||||
|
||||
# ——— Routes ———————————————————————————————————————
|
||||
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
|
||||
@app.route('/upload_image', methods=['POST'])
|
||||
def upload_image():
|
||||
f = request.files['image']
|
||||
from PIL import Image
|
||||
img = Image.open(f.stream)
|
||||
orig_w, orig_h = img.size
|
||||
|
||||
idx = len(os.listdir(IMAGES_DIR)) + 1
|
||||
ext = f.filename.rsplit('.', 1)[-1]
|
||||
fname = f"{idx:02d}.{ext}"
|
||||
img.save(os.path.join(IMAGES_DIR, fname))
|
||||
img.save(os.path.join(IMAGES_DIR_OUT, fname))
|
||||
|
||||
return jsonify({
|
||||
'image_url': f"{STATIC_BASE}/images/{fname}",
|
||||
'image_id': idx,
|
||||
'ext': ext,
|
||||
'orig_width': orig_w,
|
||||
'orig_height': orig_h
|
||||
})
|
||||
|
||||
|
||||
@app.route('/store_tracks', methods=['POST'])
|
||||
def store_tracks():
|
||||
data = request.get_json()
|
||||
image_id = data['image_id']
|
||||
ext = data['ext']
|
||||
free_tracks = data.get('tracks', [])
|
||||
circ_trajs = data.get('circle_trajectories', [])
|
||||
|
||||
# Debug lengths
|
||||
for i, tr in enumerate(free_tracks, 1):
|
||||
print(f"Freehand Track {i}: {len(tr)} points")
|
||||
for i, tr in enumerate(circ_trajs, 1):
|
||||
print(f"Circle/Static Traj {i}: {len(tr)} points")
|
||||
|
||||
def pad_pts(tr):
|
||||
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
|
||||
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
|
||||
n = pts.shape[0]
|
||||
if n < FIXED_LENGTH:
|
||||
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
|
||||
pts = np.vstack((pts, pad))
|
||||
else:
|
||||
pts = pts[:FIXED_LENGTH]
|
||||
return pts.reshape(FIXED_LENGTH, 1, 3)
|
||||
|
||||
arrs = []
|
||||
|
||||
# 1) Freehand tracks
|
||||
for i, tr in enumerate(free_tracks):
|
||||
pts = pad_pts(tr)
|
||||
arrs.append(pts,)
|
||||
|
||||
# 2) Circle + Static combined
|
||||
for i, tr in enumerate(circ_trajs):
|
||||
pts = pad_pts(tr)
|
||||
|
||||
arrs.append(pts)
|
||||
print(arrs)
|
||||
# Nothing to save?
|
||||
if not arrs:
|
||||
overlay_file = f"{image_id:02d}.png"
|
||||
return jsonify({
|
||||
'status': 'ok',
|
||||
'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
|
||||
})
|
||||
|
||||
new_tracks = np.stack(arrs, axis=0) # (T_new, FIXED_LENGTH,1,4)
|
||||
|
||||
# Load existing .pth and pad old channels to 4 if needed
|
||||
track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth")
|
||||
if os.path.exists(track_path):
|
||||
# shape (T_old, FIXED_LENGTH,1,3) or (...,4)
|
||||
old = load_existing_tracks(track_path)
|
||||
if old.ndim == 4 and old.shape[-1] == 3:
|
||||
pad = np.zeros(
|
||||
(old.shape[0], old.shape[1], old.shape[2], 1), dtype=np.float32)
|
||||
old = np.concatenate((old, pad), axis=-1)
|
||||
all_tracks = np.concatenate([old, new_tracks], axis=0)
|
||||
else:
|
||||
all_tracks = new_tracks
|
||||
|
||||
# Save updated track file
|
||||
array_to_npz_bytes(all_tracks, track_path, compressed=True)
|
||||
|
||||
# Build overlay PNG
|
||||
img_path = os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}")
|
||||
img = plt.imread(img_path)
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
ax.imshow(img)
|
||||
for t in all_tracks:
|
||||
coords = t[:, 0, :] # (FIXED_LENGTH,4)
|
||||
ax.plot(coords[:, 0][coords[:, 2] > 0.5], coords[:, 1]
|
||||
[coords[:, 2] > 0.5], marker='o', color=COLOR_CYCLE[0])
|
||||
ax.axis('off')
|
||||
overlay_file = f"{image_id:02d}.png"
|
||||
fig.savefig(os.path.join(OVERLAY_DIR, overlay_file),
|
||||
bbox_inches='tight', pad_inches=0)
|
||||
plt.close(fig)
|
||||
|
||||
# Update YAML (unchanged)
|
||||
entry = {
|
||||
"image": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/images/{image_id:02d}.{ext}"),
|
||||
"text": None,
|
||||
"track": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth")
|
||||
}
|
||||
if os.path.exists(YAML_PATH):
|
||||
with open(YAML_PATH) as yf:
|
||||
docs = yaml.safe_load(yf) or []
|
||||
else:
|
||||
docs = []
|
||||
|
||||
for e in docs:
|
||||
if e.get("image", "").endswith(f"{image_id:02d}.{ext}"):
|
||||
e.update(entry)
|
||||
break
|
||||
else:
|
||||
docs.append(entry)
|
||||
|
||||
with open(YAML_PATH, 'w') as yf:
|
||||
yaml.dump(docs, yf, default_flow_style=False)
|
||||
|
||||
return jsonify({
|
||||
'status': 'ok',
|
||||
'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
|
||||
})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
||||
571
wan/trajectory_editor/templates/index.html
Normal file
571
wan/trajectory_editor/templates/index.html
Normal file
@ -0,0 +1,571 @@
|
||||
<!-- Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Track Point Editor</title>
|
||||
<style>
|
||||
.btn-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 8px 0;
|
||||
}
|
||||
.btn-row > * { margin-right: 12px; }
|
||||
body { font-family: sans-serif; margin: 16px; }
|
||||
#topControls, #bottomControls { margin-bottom: 12px; }
|
||||
button, input, select, label { margin: 4px; }
|
||||
#canvas { border:1px solid #ccc; display: block; margin: auto; }
|
||||
#canvas { cursor: crosshair; }
|
||||
#trajProgress { width: 200px; height: 16px; margin-left:12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>Track Point Editor</h2>
|
||||
|
||||
<!-- Top controls -->
|
||||
<div id="topControls" class="btn-row">
|
||||
<input type="file" id="fileInput" accept="image/*">
|
||||
<button id="storeBtn">Store Tracks</button>
|
||||
</div>
|
||||
|
||||
<!-- Main drawing canvas -->
|
||||
<canvas id="canvas"></canvas>
|
||||
|
||||
<!-- Track controls -->
|
||||
<div id="bottomControls">
|
||||
<div class="btn-row">
|
||||
<button id="addTrackBtn">Add Freehand Track</button>
|
||||
<button id="deleteLastBtn">Delete Last Track</button>
|
||||
<progress id="trajProgress" max="121" value="0" style="display:none;"></progress>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button id="placeCircleBtn">Place Circle</button>
|
||||
<button id="addCirclePointBtn">Add Circle Point</button>
|
||||
<label>Radius:
|
||||
<input type="range" id="radiusSlider" min="10" max="800" value="50" style="display:none;">
|
||||
</label>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button id="addStaticBtn">Add Static Point</button>
|
||||
<label>Static Frames:
|
||||
<input type="number" id="staticFramesInput" value="121" min="1" style="width:60px">
|
||||
</label>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<select id="trackSelect" style="min-width:160px;"></select>
|
||||
<div id="colorIndicator"
|
||||
style="
|
||||
width:16px;
|
||||
height:16px;
|
||||
border:1px solid #444;
|
||||
display:inline-block;
|
||||
vertical-align:middle;
|
||||
margin-left:8px;
|
||||
pointer-events:none;
|
||||
visibility:hidden;
|
||||
">
|
||||
</div>
|
||||
<button id="deleteTrackBtn">Delete Selected</button>
|
||||
<button id="editTrackBtn">Edit Track</button>
|
||||
<button id="duplicateTrackBtn">Duplicate Track</button>
|
||||
</div>
|
||||
<!-- Global motion offset -->
|
||||
<div class="btn-row">
|
||||
<label>Motion X (px/frame):
|
||||
<input type="number" id="motionXInput" value="0" style="width:60px">
|
||||
</label>
|
||||
<label>Motion Y (px/frame):
|
||||
<input type="number" id="motionYInput" value="0" style="width:60px">
|
||||
</label>
|
||||
<button id="applySelectedMotionBtn">Add to Selected</button>
|
||||
<button id="applyAllMotionBtn">Add to All</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<script>
|
||||
// ——— DOM refs —————————————————————————————————————————
|
||||
const canvas = document.getElementById('canvas'),
|
||||
ctx = canvas.getContext('2d'),
|
||||
fileIn = document.getElementById('fileInput'),
|
||||
storeBtn = document.getElementById('storeBtn'),
|
||||
addTrackBtn = document.getElementById('addTrackBtn'),
|
||||
deleteLastBtn = document.getElementById('deleteLastBtn'),
|
||||
placeCircleBtn = document.getElementById('placeCircleBtn'),
|
||||
addCirclePointBtn = document.getElementById('addCirclePointBtn'),
|
||||
addStaticBtn = document.getElementById('addStaticBtn'),
|
||||
staticFramesInput = document.getElementById('staticFramesInput'),
|
||||
radiusSlider = document.getElementById('radiusSlider'),
|
||||
trackSelect = document.getElementById('trackSelect'),
|
||||
deleteTrackBtn = document.getElementById('deleteTrackBtn'),
|
||||
editTrackBtn = document.getElementById('editTrackBtn'),
|
||||
duplicateTrackBtn = document.getElementById('duplicateTrackBtn'),
|
||||
trajProg = document.getElementById('trajProgress'),
|
||||
colorIndicator = document.getElementById('colorIndicator'),
|
||||
motionXInput = document.getElementById('motionXInput'),
|
||||
motionYInput = document.getElementById('motionYInput'),
|
||||
applySelectedMotionBtn = document.getElementById('applySelectedMotionBtn'),
|
||||
applyAllMotionBtn = document.getElementById('applyAllMotionBtn');
|
||||
|
||||
let img, image_id, ext, origW, origH,
|
||||
scaleX=1, scaleY=1;
|
||||
|
||||
// track data
|
||||
let free_tracks = [], current_track = [], drawing=false, motionCounter=0;
|
||||
let circle=null, static_trajs=[];
|
||||
let mode='', selectedTrack=null, editMode=false, editInfo=null, duplicateBuffer=null;
|
||||
const COLORS=['red','green','blue','cyan','magenta','yellow','black'],
|
||||
FIXED_LENGTH=121,
|
||||
editSigma = 5/Math.sqrt(2*Math.log(2));
|
||||
|
||||
// ——— Upload & scale image ————————————————————————————
|
||||
fileIn.addEventListener('change', async e => {
|
||||
const f = e.target.files[0]; if (!f) return;
|
||||
const fd = new FormData(); fd.append('image',f);
|
||||
const res = await fetch('/upload_image',{method:'POST',body:fd});
|
||||
const js = await res.json();
|
||||
image_id=js.image_id; ext=js.ext;
|
||||
origW=js.orig_width; origH=js.orig_height;
|
||||
if(origW>=origH){
|
||||
canvas.width=800; canvas.height=Math.round(origH*800/origW);
|
||||
} else {
|
||||
canvas.height=800; canvas.width=Math.round(origW*800/origH);
|
||||
}
|
||||
scaleX=origW/canvas.width; scaleY=origH/canvas.height;
|
||||
img=new Image(); img.src=js.image_url;
|
||||
img.onload=()=>{
|
||||
free_tracks=[]; current_track=[];
|
||||
circle=null; static_trajs=[];
|
||||
mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
|
||||
trajProg.style.display='none';
|
||||
radiusSlider.style.display='none';
|
||||
trackSelect.innerHTML='';
|
||||
redraw();
|
||||
};
|
||||
});
|
||||
|
||||
// ——— Store tracks + depth —————————————————————————
|
||||
storeBtn.onclick = async () => {
|
||||
if(!image_id) return alert('Load an image first');
|
||||
const fh = free_tracks.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
|
||||
ct = (circle?.trajectories||[]).map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
|
||||
st = static_trajs.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY})));
|
||||
const payload = {
|
||||
image_id, ext,
|
||||
tracks: fh,
|
||||
circle_trajectories: ct.concat(st)
|
||||
};
|
||||
const res = await fetch('/store_tracks',{
|
||||
method:'POST',
|
||||
headers:{'Content-Type':'application/json'},
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
const js = await res.json();
|
||||
img.src=js.overlay_url;
|
||||
img.onload=()=>ctx.drawImage(img,0,0,canvas.width,canvas.height);
|
||||
|
||||
// reset UI
|
||||
free_tracks=[]; circle=null; static_trajs=[];
|
||||
mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
|
||||
trajProg.style.display='none';
|
||||
radiusSlider.style.display='none';
|
||||
trackSelect.innerHTML='';
|
||||
redraw();
|
||||
};
|
||||
|
||||
// ——— Control buttons —————————————————————————————
|
||||
addTrackBtn.onclick = ()=>{
|
||||
mode='free'; drawing=true; current_track=[]; motionCounter=0;
|
||||
trajProg.max=FIXED_LENGTH; trajProg.value=0;
|
||||
trajProg.style.display='inline-block';
|
||||
};
|
||||
deleteLastBtn.onclick = ()=>{
|
||||
if(drawing){
|
||||
drawing=false; current_track=[]; trajProg.style.display='none';
|
||||
} else if(free_tracks.length){
|
||||
free_tracks.pop(); updateTrackSelect(); redraw();
|
||||
}
|
||||
updateColorIndicator();
|
||||
};
|
||||
placeCircleBtn.onclick = ()=>{ mode='placeCircle'; drawing=false; };
|
||||
addCirclePointBtn.onclick = ()=>{ if(!circle) alert('Place circle first'); else mode='addCirclePt'; };
|
||||
addStaticBtn.onclick = ()=>{ mode='placeStatic'; };
|
||||
duplicateTrackBtn.onclick = ()=>{
|
||||
if(!selectedTrack) return alert('Select a track first');
|
||||
const arr = selectedTrack.type==='free'
|
||||
? free_tracks[selectedTrack.idx]
|
||||
: selectedTrack.type==='circle'
|
||||
? circle.trajectories[selectedTrack.idx]
|
||||
: static_trajs[selectedTrack.idx];
|
||||
duplicateBuffer = arr.map(p=>({x:p.x,y:p.y}));
|
||||
mode='duplicate'; canvas.style.cursor='copy';
|
||||
};
|
||||
|
||||
radiusSlider.oninput = ()=>{
|
||||
if(!circle) return;
|
||||
circle.radius = +radiusSlider.value;
|
||||
circle.trajectories.forEach((traj,i)=>{
|
||||
const θ = circle.angles[i];
|
||||
traj.push({
|
||||
x: circle.cx + Math.cos(θ)*circle.radius,
|
||||
y: circle.cy + Math.sin(θ)*circle.radius
|
||||
});
|
||||
});
|
||||
if(selectedTrack?.type==='circle')
|
||||
trajProg.value = circle.trajectories[selectedTrack.idx].length;
|
||||
redraw();
|
||||
};
|
||||
|
||||
deleteTrackBtn.onclick = ()=>{
|
||||
if(!selectedTrack) return;
|
||||
const {type,idx} = selectedTrack;
|
||||
if(type==='free') free_tracks.splice(idx,1);
|
||||
else if(type==='circle'){
|
||||
circle.trajectories.splice(idx,1);
|
||||
circle.angles.splice(idx,1);
|
||||
} else {
|
||||
static_trajs.splice(idx,1);
|
||||
}
|
||||
selectedTrack=null;
|
||||
trajProg.style.display='none';
|
||||
updateTrackSelect();
|
||||
redraw();
|
||||
updateColorIndicator();
|
||||
};
|
||||
|
||||
editTrackBtn.onclick = ()=>{
|
||||
if(!selectedTrack) return alert('Select a track first');
|
||||
editMode=!editMode;
|
||||
editTrackBtn.textContent = editMode?'Stop Editing':'Edit Track';
|
||||
};
|
||||
|
||||
// ——— Track select & depth init —————————————————————
|
||||
function updateTrackSelect(){
|
||||
trackSelect.innerHTML='';
|
||||
free_tracks.forEach((_,i)=>{
|
||||
const o=document.createElement('option');
|
||||
o.value=JSON.stringify({type:'free',idx:i});
|
||||
o.textContent=`Point ${i+1}`;
|
||||
trackSelect.appendChild(o);
|
||||
});
|
||||
if(circle){
|
||||
circle.trajectories.forEach((_,i)=>{
|
||||
const o=document.createElement('option');
|
||||
o.value=JSON.stringify({type:'circle',idx:i});
|
||||
o.textContent=`CirclePt ${i+1}`;
|
||||
trackSelect.appendChild(o);
|
||||
});
|
||||
}
|
||||
static_trajs.forEach((_,i)=>{
|
||||
const o=document.createElement('option');
|
||||
o.value=JSON.stringify({type:'static',idx:i});
|
||||
o.textContent=`StaticPt ${i+1}`;
|
||||
trackSelect.appendChild(o);
|
||||
});
|
||||
if(trackSelect.options.length){
|
||||
trackSelect.selectedIndex=0;
|
||||
trackSelect.onchange();
|
||||
}
|
||||
updateColorIndicator();
|
||||
}
|
||||
|
||||
function applyMotionToTrajectory(traj, dx, dy) {
|
||||
traj.forEach((pt, frameIdx) => {
|
||||
pt.x += dx * frameIdx;
|
||||
pt.y += dy * frameIdx;
|
||||
});
|
||||
}
|
||||
|
||||
applySelectedMotionBtn.onclick = () => {
|
||||
if (!selectedTrack) {
|
||||
return alert('Please select a track first');
|
||||
}
|
||||
const dx = parseFloat(motionXInput.value) || 0;
|
||||
const dy = parseFloat(motionYInput.value) || 0;
|
||||
|
||||
// pick the underlying array
|
||||
let arr = null;
|
||||
if (selectedTrack.type === 'free') {
|
||||
arr = free_tracks[selectedTrack.idx];
|
||||
} else if (selectedTrack.type === 'circle') {
|
||||
arr = circle.trajectories[selectedTrack.idx];
|
||||
} else { // 'static'
|
||||
arr = static_trajs[selectedTrack.idx];
|
||||
}
|
||||
|
||||
applyMotionToTrajectory(arr, dx, dy);
|
||||
redraw();
|
||||
};
|
||||
|
||||
// 2) Add motion to every track on the canvas
|
||||
applyAllMotionBtn.onclick = () => {
|
||||
const dx = parseFloat(motionXInput.value) || 0;
|
||||
const dy = parseFloat(motionYInput.value) || 0;
|
||||
|
||||
// freehand tracks
|
||||
free_tracks.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
|
||||
// circle‑based tracks
|
||||
if (circle) {
|
||||
circle.trajectories.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
|
||||
}
|
||||
// static points (now will move over frames)
|
||||
static_trajs.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
|
||||
|
||||
redraw();
|
||||
};
|
||||
|
||||
trackSelect.onchange = ()=>{
|
||||
if(!trackSelect.value){
|
||||
selectedTrack=null;
|
||||
trajProg.style.display='none';
|
||||
return;
|
||||
}
|
||||
selectedTrack = JSON.parse(trackSelect.value);
|
||||
|
||||
if(selectedTrack.type==='circle'){
|
||||
trajProg.style.display='inline-block';
|
||||
trajProg.max=FIXED_LENGTH;
|
||||
trajProg.value=circle.trajectories[selectedTrack.idx].length;
|
||||
} else if(selectedTrack.type==='free'){
|
||||
trajProg.style.display='inline-block';
|
||||
trajProg.max=FIXED_LENGTH;
|
||||
trajProg.value=free_tracks[selectedTrack.idx].length;
|
||||
} else {
|
||||
trajProg.style.display='none';
|
||||
}
|
||||
updateColorIndicator();
|
||||
};
|
||||
|
||||
// ——— Canvas drawing ————————————————————————————————
|
||||
canvas.addEventListener('mousedown', e=>{
|
||||
const r=canvas.getBoundingClientRect(),
|
||||
x=e.clientX-r.left, y=e.clientY-r.top;
|
||||
|
||||
// place circle
|
||||
if(mode==='placeCircle'){
|
||||
circle={cx:x,cy:y,radius:50,angles:[],trajectories:[]};
|
||||
radiusSlider.max=Math.min(canvas.width,canvas.height)|0;
|
||||
radiusSlider.value=50; radiusSlider.style.display='inline';
|
||||
mode=''; updateTrackSelect(); redraw(); return;
|
||||
}
|
||||
// add circle point
|
||||
if(mode==='addCirclePt'){
|
||||
const dx=x-circle.cx, dy=y-circle.cy;
|
||||
const θ=Math.atan2(dy,dx);
|
||||
const px=circle.cx+Math.cos(θ)*circle.radius;
|
||||
const py=circle.cy+Math.sin(θ)*circle.radius;
|
||||
circle.angles.push(θ);
|
||||
circle.trajectories.push([{x:px,y:py}]);
|
||||
mode=''; updateTrackSelect(); redraw(); return;
|
||||
}
|
||||
// add static
|
||||
if (mode === 'placeStatic') {
|
||||
// how many frames to “hold” the point
|
||||
const len = parseInt(staticFramesInput.value, 10) || FIXED_LENGTH;
|
||||
// duplicate the click‐point len times
|
||||
const traj = Array.from({ length: len }, () => ({ x, y }));
|
||||
// push into free_tracks so it's drawn & edited just like any freehand curve
|
||||
free_tracks.push(traj);
|
||||
|
||||
// reset state
|
||||
mode = '';
|
||||
updateTrackSelect();
|
||||
redraw();
|
||||
return;
|
||||
}
|
||||
// duplicate
|
||||
if(mode==='duplicate' && duplicateBuffer){
|
||||
const orig = duplicateBuffer;
|
||||
// click defines translation by first point
|
||||
const dx = x - orig[0].x, dy = y - orig[0].y;
|
||||
const newTr = orig.map(p=>({x:p.x+dx, y:p.y+dy}));
|
||||
free_tracks.push(newTr);
|
||||
mode=''; duplicateBuffer=null; canvas.style.cursor='crosshair';
|
||||
updateTrackSelect(); redraw(); return;
|
||||
}
|
||||
// editing
|
||||
if(editMode && selectedTrack){
|
||||
const arr = selectedTrack.type==='free'
|
||||
? free_tracks[selectedTrack.idx]
|
||||
: selectedTrack.type==='circle'
|
||||
? circle.trajectories[selectedTrack.idx]
|
||||
: static_trajs[selectedTrack.idx];
|
||||
let best=0,bd=Infinity;
|
||||
arr.forEach((p,i)=>{
|
||||
const d=(p.x-x)**2+(p.y-y)**2;
|
||||
if(d<bd){ bd=d; best=i; }
|
||||
});
|
||||
editInfo={ trackType:selectedTrack.type,
|
||||
trackIdx:selectedTrack.idx,
|
||||
ptIdx:best,
|
||||
startX:x, startY:y };
|
||||
return;
|
||||
}
|
||||
// freehand start
|
||||
if(mode==='free'){
|
||||
drawing=true; motionCounter=0;
|
||||
current_track=[{x,y}];
|
||||
redraw();
|
||||
}
|
||||
});
|
||||
|
||||
canvas.addEventListener('mousemove', e=>{
|
||||
const r=canvas.getBoundingClientRect(),
|
||||
x=e.clientX-r.left, y=e.clientY-r.top;
|
||||
// edit mode
|
||||
if(editMode && editInfo){
|
||||
const dx=x-editInfo.startX,
|
||||
dy=y-editInfo.startY;
|
||||
const {trackType,trackIdx,ptIdx} = editInfo;
|
||||
const arr = trackType==='free'
|
||||
? free_tracks[trackIdx]
|
||||
: trackType==='circle'
|
||||
? circle.trajectories[trackIdx]
|
||||
: static_trajs[trackIdx];
|
||||
arr.forEach((p,i)=>{
|
||||
const d=i-ptIdx;
|
||||
const w=Math.exp(-0.5*(d*d)/(editSigma*editSigma));
|
||||
p.x+=dx*w; p.y+=dy*w;
|
||||
});
|
||||
editInfo.startX=x; editInfo.startY=y;
|
||||
if(selectedTrack?.type==='circle')
|
||||
trajProg.value=circle.trajectories[selectedTrack.idx].length;
|
||||
redraw(); return;
|
||||
}
|
||||
// freehand draw
|
||||
if(drawing && (e.buttons&1)){
|
||||
motionCounter++;
|
||||
if(motionCounter%2===0){
|
||||
current_track.push({x,y});
|
||||
trajProg.value = Math.min(current_track.length, trajProg.max);
|
||||
redraw();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
canvas.addEventListener('mouseup', ()=>{
|
||||
if(editMode && editInfo){ editInfo=null; return; }
|
||||
if(drawing){
|
||||
free_tracks.push(current_track.slice());
|
||||
drawing=false; current_track=[];
|
||||
updateTrackSelect(); redraw();
|
||||
}
|
||||
});
|
||||
|
||||
function updateColorIndicator() {
|
||||
const idx = trackSelect.selectedIndex;
|
||||
if (idx < 0) {
|
||||
colorIndicator.style.visibility = 'hidden';
|
||||
return;
|
||||
}
|
||||
// Pick the color by index
|
||||
const col = COLORS[idx % COLORS.length];
|
||||
colorIndicator.style.backgroundColor = col;
|
||||
colorIndicator.style.visibility = 'visible';
|
||||
}
|
||||
|
||||
// ——— redraw ———
|
||||
function redraw(){
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
if (img.complete) ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
||||
|
||||
// set a fatter line for all strokes
|
||||
ctx.lineWidth = 2;
|
||||
|
||||
// — freehand (and static‑turned‑freehand) tracks —
|
||||
free_tracks.forEach((tr, i) => {
|
||||
const col = COLORS[i % COLORS.length];
|
||||
ctx.strokeStyle = col;
|
||||
ctx.fillStyle = col;
|
||||
|
||||
if (tr.length === 0) return;
|
||||
|
||||
// check if every point equals the first
|
||||
const allSame = tr.every(p => p.x === tr[0].x && p.y === tr[0].y);
|
||||
|
||||
if (allSame) {
|
||||
// draw a filled circle for a “static” dot
|
||||
ctx.beginPath();
|
||||
ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
} else {
|
||||
// normal polyline
|
||||
ctx.beginPath();
|
||||
tr.forEach((p, j) =>
|
||||
j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
|
||||
);
|
||||
ctx.stroke();
|
||||
}
|
||||
});
|
||||
|
||||
if(drawing && current_track.length){
|
||||
ctx.strokeStyle='black';
|
||||
ctx.beginPath();
|
||||
current_track.forEach((p,j)=>
|
||||
j? ctx.lineTo(p.x,p.y): ctx.moveTo(p.x,p.y));
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// — circle trajectories —
|
||||
if (circle) {
|
||||
// circle outline
|
||||
ctx.strokeStyle = 'white';
|
||||
ctx.lineWidth = 1;
|
||||
ctx.beginPath();
|
||||
ctx.arc(circle.cx, circle.cy, circle.radius, 0, 2 * Math.PI);
|
||||
ctx.stroke();
|
||||
|
||||
circle.trajectories.forEach((tr, i) => {
|
||||
const col = COLORS[(free_tracks.length + i) % COLORS.length];
|
||||
ctx.strokeStyle = col;
|
||||
ctx.fillStyle = col;
|
||||
ctx.lineWidth = 2;
|
||||
|
||||
if (tr.length <= 1) {
|
||||
// single‑point circle trajectory → dot
|
||||
ctx.beginPath();
|
||||
ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
} else {
|
||||
// normal circle track
|
||||
ctx.beginPath();
|
||||
tr.forEach((p, j) =>
|
||||
j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
|
||||
);
|
||||
ctx.stroke();
|
||||
|
||||
// white handle at last point
|
||||
const lp = tr[tr.length - 1];
|
||||
ctx.fillStyle = 'white';
|
||||
ctx.beginPath();
|
||||
ctx.arc(lp.x, lp.y, 4, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// — static_trajs (if you still use them separately) —
|
||||
static_trajs.forEach((tr, i) => {
|
||||
const p = tr[0];
|
||||
ctx.fillStyle = 'orange';
|
||||
ctx.beginPath();
|
||||
ctx.arc(p.x, p.y, 5, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
74
wan/utils/motion.py
Normal file
74
wan/utils/motion.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os, io
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_tracks_inference(tracks, height, width, quant_multi: Optional[int] = 8, **kwargs):
|
||||
if isinstance(tracks, str):
|
||||
tracks = torch.load(tracks)
|
||||
|
||||
tracks_np = unzip_to_array(tracks)
|
||||
|
||||
tracks = process_tracks(
|
||||
tracks_np, (width, height), quant_multi=quant_multi, **kwargs
|
||||
)
|
||||
|
||||
return tracks
|
||||
|
||||
|
||||
def unzip_to_array(
|
||||
data: bytes, key: Union[str, List[str]] = "array"
|
||||
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||
bytes_io = io.BytesIO(data)
|
||||
|
||||
if isinstance(key, str):
|
||||
# Load the NPZ data from the BytesIO object
|
||||
with np.load(bytes_io) as data:
|
||||
return data[key]
|
||||
else:
|
||||
get = {}
|
||||
with np.load(bytes_io) as data:
|
||||
for k in key:
|
||||
get[k] = data[k]
|
||||
return get
|
||||
|
||||
|
||||
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], quant_multi: int = 8, **kwargs):
|
||||
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
|
||||
# frame_size: tuple (W, H)
|
||||
|
||||
tracks = torch.from_numpy(tracks_np).float() / quant_multi
|
||||
if tracks.shape[1] == 121:
|
||||
tracks = torch.permute(tracks, (1, 0, 2, 3))
|
||||
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
|
||||
short_edge = min(*frame_size)
|
||||
|
||||
tracks = tracks - torch.tensor([*frame_size]).type_as(tracks) / 2
|
||||
tracks = tracks / short_edge * 2
|
||||
|
||||
visibles = visibles * 2 - 1
|
||||
|
||||
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
|
||||
|
||||
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
|
||||
out_0 = out_[:1]
|
||||
out_l = out_[1:] # 121 => 120 | 1
|
||||
out_l = torch.repeat_interleave(out_l, 2, dim=0)[1::3] # 120 => 240 => 80
|
||||
return torch.cat([out_0, out_l], dim=0)
|
||||
@ -78,8 +78,15 @@ def remove_background(img, session=None):
|
||||
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
||||
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
|
||||
|
||||
def save_image(tensor_image, name):
|
||||
import numpy as np
|
||||
tensor_image = tensor_image.clone()
|
||||
tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0)
|
||||
Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name)
|
||||
|
||||
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
|
||||
if fit_into_canvas == None:
|
||||
return height, width
|
||||
if fit_into_canvas:
|
||||
scale1 = min(canvas_height / height, canvas_width / width)
|
||||
scale2 = min(canvas_width / height, canvas_height / width)
|
||||
@ -337,22 +344,3 @@ def create_progress_hook(filename):
|
||||
return progress_hook(block_num, block_size, total_size, filename)
|
||||
return hook
|
||||
|
||||
def save_quantized_model(model, model_filename, dtype, config_file):
|
||||
if "quanto" in model_filename:
|
||||
return
|
||||
from mmgp import offload
|
||||
if dtype == torch.bfloat16:
|
||||
model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16")
|
||||
elif dtype == torch.float16:
|
||||
model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16")
|
||||
|
||||
for rep in ["mfp16", "fp16", "mbf16", "bf16"]:
|
||||
if "_" + rep in model_filename:
|
||||
model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8")
|
||||
break
|
||||
if not "quanto" in model_filename:
|
||||
pos = model_filename.rfind(".")
|
||||
model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
|
||||
|
||||
if not os.path.isfile(model_filename):
|
||||
offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file)
|
||||
|
||||
@ -196,7 +196,7 @@ class VaceVideoProcessor(object):
|
||||
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True):
|
||||
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= None):
|
||||
if self.keep_last:
|
||||
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
|
||||
else:
|
||||
@ -208,23 +208,23 @@ class VaceVideoProcessor(object):
|
||||
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
||||
|
||||
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs):
|
||||
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = None, **kwargs):
|
||||
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
||||
# read video
|
||||
import decord
|
||||
decord.bridge.set_bridge('torch')
|
||||
readers = []
|
||||
src_video = None
|
||||
src_videos = []
|
||||
for data_k in data_key_batch:
|
||||
if torch.is_tensor(data_k):
|
||||
src_video = data_k
|
||||
src_videos.append(data_k)
|
||||
else:
|
||||
reader = decord.VideoReader(data_k)
|
||||
readers.append(reader)
|
||||
|
||||
if src_video != None:
|
||||
if len(src_videos) >0:
|
||||
fps = 16
|
||||
length = src_video.shape[0] + start_frame
|
||||
length = src_videos[0].shape[0] + start_frame
|
||||
if len(readers) > 0:
|
||||
min_readers = min([len(r) for r in readers])
|
||||
length = min(length, min_readers )
|
||||
@ -234,17 +234,17 @@ class VaceVideoProcessor(object):
|
||||
# frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||
# frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
|
||||
if src_video != None:
|
||||
src_video = src_video[:max_frames]
|
||||
h, w = src_video.shape[1:3]
|
||||
if len(src_videos) >0:
|
||||
src_videos = [ src_video[:max_frames] for src_video in src_videos]
|
||||
h, w = src_videos[0].shape[1:3]
|
||||
else:
|
||||
h, w = readers[0].next().shape[:2]
|
||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame )
|
||||
|
||||
# preprocess video
|
||||
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
||||
if src_video != None:
|
||||
videos = [src_video] + videos
|
||||
if len(src_videos) >0:
|
||||
videos = src_videos + videos
|
||||
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||
return *videos, frame_ids, (oh, ow), fps
|
||||
# return videos if len(videos) > 1 else videos[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user