Added Recammaster support

This commit is contained in:
DeepBeepMeep 2025-04-17 12:59:54 +02:00
parent 19de2b1989
commit 30f2889a35
7 changed files with 228 additions and 258 deletions

View File

@ -1,9 +1,4 @@
# Wan2.1 GP
<p align="center">
💜 <a href=""><b>Wan</b></a> &nbsp&nbsp &nbsp&nbsp 🖥️ <a href="https://github.com/Wan-Video/Wan2.1">GitHub</a> &nbsp&nbsp | &nbsp&nbsp🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="">Paper (Coming soon)</a> &nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://wanxai.com">Blog</a> &nbsp&nbsp | &nbsp&nbsp💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat Group</a>&nbsp&nbsp | &nbsp&nbsp 📖 <a href="https://discord.gg/g7efUW9jGV">Discord</a>&nbsp&nbsp
<br>
# WanGP
-----
<p align="center">
@ -15,6 +10,7 @@
## 🔥 Latest News!!
* April 17 2025: 👋 Wan 2.1GP v4.1: ReCamMaster model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 denoising steps to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
- A new UI: tabs were replaced by a Dropdown box to easily switch models
- A new queuing system that lets you stack in a queue as many text2video, image2video tasks, ... as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor on this new feature
@ -254,7 +250,7 @@ Each preset, is a file with ".lset" extension stored in the loras directory and
Last but not least, you can pre-activate the corresponding Loras and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server:
```bash
python wgp.py.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
python wgp.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
```
You will find prebuilt Loras on https://civitai.com/ or you can build your own with tools such as kohya or onetrainer.

View File

@ -319,7 +319,6 @@ class WanI2V:
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
@ -329,7 +328,6 @@ class WanI2V:
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
@ -340,7 +338,6 @@ class WanI2V:
'context': [context[0]],
'context2': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self,

View File

@ -152,91 +152,52 @@ __all__ = [
def pay_attention(
qkv_list,
# q,
# k,
# v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
force_attention= None,
cross_attn= False
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
q,k,v = qkv_list
qkv_list.clear()
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
assert b==1
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
# if q_scale is not None:
# q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.'
)
if attn=="sage" or attn=="flash":
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
# apply attention
if attn=="sage":
x = sageattn_varlen_wrapper(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_q= cu_seqlens_q,
cu_seqlens_kv= cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
).unflatten(0, (b, lq))
@ -314,10 +275,8 @@ def pay_attention(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_q= cu_seqlens_q,
cu_seqlens_kv= cu_seqlens_k,
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
@ -330,10 +289,8 @@ def pay_attention(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_q= [0, lq],
cu_seqlens_kv=[0, lk],
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
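
The simplification above is valid because `pay_attention` asserts a batch size of 1: with a single sequence, the cumulative-sequence-length tensors expected by the varlen attention kernels reduce to the two offsets `[0, length]`. A minimal sketch (plain PyTorch, assumed example lengths) checking the equivalence:

```python
import torch

# Minimal sketch: with a batch size of 1 (asserted in pay_attention), the
# cumulative-sequence-length tensors expected by varlen attention kernels
# collapse to the two offsets [0, length], so the per-batch cumsum is unneeded.
lq, lk = 75600, 512                                   # assumed query / key lengths
q_lens = torch.tensor([lq], dtype=torch.int32)        # one sequence in the batch
k_lens = torch.tensor([lk], dtype=torch.int32)

# construction used before this commit: prepend a zero, then running sum
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)

# simplified construction used after this commit
assert torch.equal(cu_seqlens_q, torch.tensor([0, lq], dtype=torch.int32))
assert torch.equal(cu_seqlens_k, torch.tensor([0, lk], dtype=torch.int32))
```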

View File

@ -1,6 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from einops import rearrange
import torch
import torch.cuda.amp as amp
import torch.nn as nn
@ -167,11 +167,10 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, seq_lens, grid_sizes, freqs):
def forward(self, xlist, grid_sizes, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
@ -196,10 +195,6 @@ class WanSelfAttention(nn.Module):
del q,k,v
x = pay_attention(
qkv_list,
# q=q,
# k=k,
# v=v,
# k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
@ -209,12 +204,11 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, xlist, context, context_lens):
def forward(self, xlist, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
x = xlist[0]
xlist.clear()
@ -233,7 +227,7 @@ class WanT2VCrossAttention(WanSelfAttention):
# compute attention
qvl_list=[q, k, v]
del q, k, v
x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)
x = pay_attention(qvl_list, cross_attn= True)
# output
x = x.flatten(2)
@ -256,12 +250,11 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, context, context_lens):
def forward(self, xlist, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep !
@ -287,7 +280,7 @@ class WanI2VCrossAttention(WanSelfAttention):
qkv_list = [q, k, v]
del k,v
x = pay_attention(qkv_list, k_lens=context_lens)
x = pay_attention(qkv_list)
k_img = self.k_img(context_img)
self.norm_k_img(k_img)
@ -295,7 +288,7 @@ class WanI2VCrossAttention(WanSelfAttention):
v_img = self.v_img(context_img).view(b, -1, n, d)
qkv_list = [q, k_img, v_img]
del q, k_img, v_img
img_x = pay_attention(qkv_list, k_lens=None)
img_x = pay_attention(qkv_list)
# compute attention
@ -362,30 +355,26 @@ class WanAttentionBlock(nn.Module):
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
hints= None,
context_scale=1.0,
cam_emb= None
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
hint = None
if self.block_id is not None and hints is not None:
kwargs = {
"seq_lens" : seq_lens,
"grid_sizes" : grid_sizes,
"freqs" :freqs,
"context" : context,
"context_lens" : context_lens,
"e" : e,
}
if self.block_id == 0:
@ -394,20 +383,31 @@ class WanAttentionBlock(nn.Module):
hint = self.vace(hints, None, **kwargs)
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
x_mod = self.norm1(x)
x_mod *= 1 + e[1]
x_mod += e[0]
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
x_mod += cam_emb
xlist = [x_mod]
del x_mod
y = self.self_attn( xlist, seq_lens, grid_sizes,freqs)
y = self.self_attn( xlist, grid_sizes, freqs)
if cam_emb != None:
y = self.projector(y)
# x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
x.addcmul_(y, e[2])
del y
y = self.norm3(x)
ylist= [y]
del y
x += self.cross_attn(ylist, context, context_lens)
x += self.cross_attn(ylist, context)
y = self.norm2(x)
y *= 1 + e[4]
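
The self-attention hunk above broadcasts the per-frame camera embedding over the latent grid before adding it to the modulated hidden states. A minimal sketch with assumed toy shapes of that repeat/rearrange:

```python
import torch
from einops import rearrange

# Minimal sketch (assumed toy shapes) of the camera-embedding broadcast above:
# one embedding per latent frame is doubled along the frame axis, tiled over the
# spatial grid, then flattened to the token sequence and added to x_mod.
b, f, h, w, d = 1, 2, 4, 6, 16          # latent frames, grid height/width, hidden dim
cam_emb = torch.randn(b, f, d)          # output of cam_encoder, one row per frame

cam_emb = cam_emb.repeat(1, 2, 1)                                   # [b, 2f, d]
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1)   # [b, 2f, h, w, d]
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')            # [b, 2f*h*w, d]

x_mod = torch.randn(b, 2 * f * h * w, d)    # modulated hidden states, same token count
x_mod = x_mod + cam_emb                      # shapes line up, plain element-wise add
```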
@ -552,6 +552,7 @@ class WanModel(ModelMixin, ConfigMixin):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
recammaster = False
):
r"""
Initialize the diffusion model backbone.
@ -666,6 +667,15 @@ class WanModel(ModelMixin, ConfigMixin):
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
if recammaster :
dim=self.blocks[0].self_attn.q.weight.shape[0]
for block in self.blocks:
block.cam_encoder = nn.Linear(12, dim)
block.projector = nn.Linear(dim, dim)
block.cam_encoder.weight.data.zero_()
block.cam_encoder.bias.data.zero_()
block.projector.weight = nn.Parameter(torch.eye(dim))
block.projector.bias = nn.Parameter(torch.zeros(dim))
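
With this initialization the extra ReCamMaster layers are a no-op until trained weights are loaded: the zeroed `cam_encoder` maps any camera pose to zeros and the identity `projector` passes the self-attention output through unchanged. A minimal sketch (toy dimension) checking that property:

```python
import torch
import torch.nn as nn

# Minimal sketch (toy dimension) of why this initialization makes the ReCamMaster
# branch a no-op until trained weights are loaded: the zeroed encoder maps any
# camera pose to zeros, and the identity projector returns its input unchanged.
dim = 8
cam_encoder = nn.Linear(12, dim)
projector = nn.Linear(dim, dim)
cam_encoder.weight.data.zero_()
cam_encoder.bias.data.zero_()
projector.weight = nn.Parameter(torch.eye(dim))
projector.bias = nn.Parameter(torch.zeros(dim))

pose = torch.randn(1, 12)      # arbitrary camera pose vector
y = torch.randn(1, dim)        # arbitrary self-attention output
assert torch.allclose(cam_encoder(pose), torch.zeros(1, dim))   # adds nothing to x_mod
assert torch.allclose(projector(y), y)                          # leaves y untouched
```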
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
@ -716,7 +726,6 @@ class WanModel(ModelMixin, ConfigMixin):
x,
t,
context,
seq_len,
vace_context = None,
vace_context_scale=1.0,
clip_fea=None,
@ -729,28 +738,9 @@ class WanModel(ModelMixin, ConfigMixin):
max_steps = 0,
slg_layers=None,
callback = None,
cam_emb: torch.Tensor = None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
@ -775,15 +765,7 @@ class WanModel(ModelMixin, ConfigMixin):
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
if len(x)==1 and seq_len == x[0].size(1):
x = x[0]
else:
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
x = x[0]
# time embeddings
e = self.time_embedding(
@ -791,7 +773,6 @@ class WanModel(ModelMixin, ConfigMixin):
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
@ -825,10 +806,9 @@ class WanModel(ModelMixin, ConfigMixin):
# arguments
kwargs = dict(
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context_lens=context_lens,
cam_emb = cam_emb
)
if vace_context == None:
@ -837,13 +817,7 @@ class WanModel(ModelMixin, ConfigMixin):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
if (len(c) == 1 and seq_len == c[0].size(1)):
c = c[0]
else:
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
c = c[0]
kwargs['context_scale'] = vace_context_scale
hints_list = [ [c] for _ in range(len(x_list)) ]

View File

@ -10,6 +10,7 @@ from contextlib import contextmanager
from functools import partial
from mmgp import offload
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
@ -106,6 +107,9 @@ class WanT2V:
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
@ -278,7 +282,9 @@ class WanT2V:
input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
input_ref_images = None,
source_video=None,
target_camera=None,
context_scale=1.0,
size=(1280, 720),
frame_num=81,
@ -340,18 +346,19 @@ class WanT2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
# self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if target_camera != None:
size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device)
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
del source_video
# Process target camera (recammaster)
from wan.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
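
Before VAE encoding, the ReCamMaster source video is converted from uint8 HWC frames in [0, 255] to a [C, F, H, W] float tensor in [-1, 1]. A minimal sketch of that normalization with an assumed toy clip (float32 here; the pipeline uses the model dtype):

```python
import torch

# Minimal sketch (assumed toy clip) of the source-video normalization above:
# uint8 [F, H, W, C] frames in [0, 255] become a [C, F, H, W] float tensor
# in [-1, 1] before being passed to the VAE encoder.
frames = torch.randint(0, 256, (81, 32, 48, 3), dtype=torch.uint8)

video = frames.to(torch.float32)
video = video.permute(3, 0, 1, 2)     # [C, F, H, W]
video = video.div_(127.5).sub_(1.)    # 0 -> -1.0, 255 -> 1.0

assert video.shape == (3, 81, 32, 48)
assert video.min() >= -1.0 and video.max() <= 1.0
```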
if input_frames != None:
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
@ -377,22 +384,7 @@ class WanT2V:
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
# evaluation mode
@ -419,11 +411,24 @@ class WanT2V:
# sample videos
latents = noise
del noise
batch_size =len(latents)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
if target_camera != None:
shape = list(latents[0].shape[1:])
shape[0] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
arg_c.update(recam_dict)
arg_null.update(recam_dict)
arg_both.update(recam_dict)
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
arg_c.update(vace_dict)
@ -435,7 +440,10 @@ class WanT2V:
if callback != None:
callback(-1, True)
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
else:
latent_model_input = latents
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
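
For ReCamMaster the noisy target latents and the clean source latents are concatenated along the frame axis, which is why the RoPE table above is built for twice the latent frame count and why the prediction is later cut back with `noise_pred[:, :target_shape[1]]`. A minimal sketch with assumed toy shapes:

```python
import torch

# Minimal sketch (assumed toy latent shapes) of the ReCamMaster layout in the loop
# above: the noisy target latents and the encoded source latents are stacked along
# the frame axis, so the RoPE table is built for 2*F frames and the prediction is
# later cut back to the first F frames.
C, F, H, W = 16, 21, 30, 52
latents = torch.randn(C, F, H, W)           # latents being denoised
source_latents = torch.randn(C, F, H, W)    # VAE-encoded source video

latent_model_input = torch.cat([latents, source_latents], dim=1)   # [C, 2F, H, W]
assert latent_model_input.shape[1] == 2 * F

noise_pred = torch.randn(C, 2 * F, H, W)    # model output covers both halves
noise_pred = noise_pred[:, :F]              # keep only the target-frame prediction
assert noise_pred.shape == (C, F, H, W)
```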
@ -443,7 +451,6 @@ class WanT2V:
offload.set_step_no_for_lora(self.model, i)
timestep = torch.stack(timestep)
# self.model.to(self.device)
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
@ -459,7 +466,7 @@ class WanT2V:
if self._interrupt:
return None
del latent_model_input
# del latent_model_input
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
@ -478,7 +485,7 @@ class WanT2V:
del noise_pred_uncond
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
noise_pred[:, :target_shape[1]].unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
@ -490,24 +497,14 @@ class WanT2V:
callback(i, False)
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
if input_frames == None:
videos = self.vae.decode(x0, VAE_tile_size)
else:
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
if input_frames == None:
videos = self.vae.decode(x0, VAE_tile_size)
else:
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
del noise, latents
del latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None

View File

@ -69,25 +69,21 @@ def remove_background(img, session=None):
def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
if rm_background:
session = new_session()
output_list =[]
for img in img_list:
width, height = img.size
white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
scale = min(canvas_height / height, canvas_width / width)
new_height = int(height * scale)
new_width = int(width * scale)
scale = (budget_height * budget_width / (height * width))**(1/2)
new_height = int( round(height * scale / 16) * 16)
new_width = int( round(width * scale / 16) * 16)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background:
resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
img = Image.fromarray(white_canvas)
output_list.append(img)
output_list.append(resized_image)
return output_list
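
The reference-image helper now scales toward a pixel budget while keeping the aspect ratio and rounding both sides to multiples of 16, instead of letter-boxing onto a white canvas. A minimal sketch, with an assumed example size, of that computation:

```python
# Minimal sketch (assumed example size) of the pixel-budget resize used above:
# the aspect ratio is kept, the area is scaled toward budget_width x budget_height,
# and both sides are rounded to multiples of 16.
def budget_resize(width, height, budget_width, budget_height):
    scale = (budget_height * budget_width / (height * width)) ** 0.5
    new_height = int(round(height * scale / 16) * 16)
    new_width = int(round(width * scale / 16) * 16)
    return new_width, new_height

print(budget_resize(1920, 1080, 832, 480))   # -> (848, 480), close to the 832x480 budget
```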

wgp.py (183 changed lines)
View File

@ -174,7 +174,25 @@ def process_prompt_and_add_tasks(state, model_choice):
sliding_window_repeat = inputs["sliding_window_repeat"]
sliding_window = sliding_window_repeat > 0
if "Vace" in model_filename:
if "recam" in model_filename:
video_source = inputs["video_source"]
if video_source == None:
gr.Info("You must provide a Source Video")
return
frames = get_resampled_video(video_source, 0, 81)
if len(frames)<81:
gr.Info("Recammaster source video should be at least 81 frames one the resampling at 16 fps has been done")
return
for single_prompt in prompts:
extra_inputs = {
"prompt" : single_prompt,
"video_source" : video_source,
}
inputs.update(extra_inputs)
add_video_task(**inputs)
elif "Vace" in model_filename:
video_prompt_type = inputs["video_prompt_type"]
image_refs = inputs["image_refs"]
video_guide = inputs["video_guide"]
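
The ReCamMaster branch above rejects source videos that yield fewer than 81 frames after resampling to 16 fps, which works out to roughly five seconds of footage regardless of the original frame rate. A quick sanity check:

```python
# Quick check of the constraint above: 81 frames after resampling to 16 fps means
# the source clip must be at least ~5.1 seconds long, whatever its original fps.
required_frames = 81
target_fps = 16
print(f"minimum source duration: {required_frames / target_fps:.2f}s")   # 5.06s
```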
@ -334,14 +352,8 @@ def process_prompt_and_add_tasks(state, model_choice):
queue= gen.get("queue", [])
return update_queue_data(queue)
def add_video_task(**inputs):
global task_id
state = inputs["state"]
gen = get_gen_info(state)
queue = gen["queue"]
task_id += 1
current_task_id = task_id
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
def get_preview_images(inputs):
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"]
start_image_data = None
end_image_data = None
for name in inputs_to_query:
@ -353,6 +365,17 @@ def add_video_task(**inputs):
else:
end_image_data = image
break
return start_image_data, end_image_data
def add_video_task(**inputs):
global task_id
state = inputs["state"]
gen = get_gen_info(state)
queue = gen["queue"]
task_id += 1
current_task_id = task_id
start_image_data, end_image_data = get_preview_images(inputs)
queue.append({
"id": current_task_id,
@ -434,7 +457,7 @@ def save_queue_action(state):
task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask"]
video_keys = ["video_guide", "video_mask", "video_source"]
for key in image_keys:
images_pil = params_copy.get(key)
@ -595,7 +618,7 @@ def load_queue_action(filepath, state):
params['state'] = state
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask"]
video_keys = ["video_guide", "video_mask", "video_source"]
loaded_pil_images = {}
loaded_video_paths = {}
@ -652,20 +675,10 @@ def load_queue_action(filepath, state):
print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}")
params.pop(key, None)
primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params)
primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs")
secondary_preview_pil_list = loaded_pil_images.get("image_end")
primary_preview_pil = None
if primary_preview_pil_list:
primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list
secondary_preview_pil = None
if secondary_preview_pil_list:
secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list
start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if primary_preview_pil_list[0] else None
end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list[0] else None
top_level_start_image = params.get("image_start") or params.get("image_refs")
top_level_end_image = params.get("image_end")
@ -818,7 +831,7 @@ def autosave_queue():
task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask"]
video_keys = ["video_guide", "video_mask", "video_source"]
for key in image_keys:
images_pil = params_copy.get(key)
@ -1352,13 +1365,13 @@ quantizeTransformer = args.quantize_transformer
check_loras = args.check_loras ==1
advanced = args.advanced
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
server_config_filename = "wgp_config.json"
if not os.path.isdir("settings"):
os.mkdir("settings")
os.mkdir("settings")
if os.path.isfile("t2v_settings.json"):
for f in glob.glob(os.path.join(".", "*_settings.json*")):
target_file = os.path.join("settings", Path(f).parts[-1] )
@ -1391,30 +1404,16 @@ else:
server_config = json.loads(text)
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp"]
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B" }
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B" }
def get_model_type(model_filename):
if "text2video" in model_filename and "14B" in model_filename:
return "t2v"
elif "text2video" in model_filename and "1.3B" in model_filename:
return "t2v_1.3B"
elif "Fun_InP" in model_filename and "1.3B" in model_filename:
return "fun_inp_1.3B"
elif "Fun_InP" in model_filename and "14B" in model_filename:
return "fun_inp"
elif "image2video_480p" in model_filename :
return "i2v"
elif "image2video_720p" in model_filename :
return "i2v_720p"
elif "Vace" in model_filename and "1.3B" in model_filename:
return "vace_1.3B"
elif "Vace" in model_filename and "14B" in model_filename:
return "vace"
else:
raise Exception("Unknown model:" + model_filename)
for model_type, signature in model_signatures.items():
if signature in model_filename:
return model_type
raise Exception("Unknown model:" + model_filename)
def test_class_i2v(model_filename):
return "image2video" in model_filename or "Fun_InP" in model_filename
@ -1862,6 +1861,9 @@ def get_model_name(model_filename):
elif "image" in model_filename:
model_name = "Wan2.1 image2video"
model_name += " 720p" if "720p" in model_filename else " 480p"
elif "recam" in model_filename:
model_name = "ReCamMaster"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
else:
model_name = "Wan2.1 text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
@ -2145,9 +2147,7 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0):
def get_resampled_video(video_in, start_frame, max_frames):
from wan.utils.utils import resample
import decord
@ -2158,15 +2158,26 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
frames_list = reader.get_batch(frame_nos)
return frames_list
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False):
frames_list = get_resampled_video(video_in, start_frame, max_frames)
if len(frames_list) == 0:
return None
frame_height, frame_width, _ = frames_list[0].shape
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
# scale = min(height / frame_height, width / frame_width)
if fit_canvas :
scale = min(height / frame_height, width / frame_width)
else:
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
new_height = (int(frame_height * scale) // 16) * 16
new_width = (int(frame_width * scale) // 16) * 16
# if fit_canvas :
# new_height = height
# new_width = width
processed_frames_list = []
for frame in frames_list:
@ -2188,12 +2199,17 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
"PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
}
anno_ins = DepthVideoAnnotator(cfg_dict)
else:
elif process_type=="gray":
from preprocessing.gray import GrayVideoAnnotator
cfg_dict = {}
anno_ins = GrayVideoAnnotator(cfg_dict)
np_frames = anno_ins.forward(processed_frames_list)
else:
anno_ins = None
if anno_ins == None:
np_frames = [np.array(frame) for frame in processed_frames_list]
else:
np_frames = anno_ins.forward(processed_frames_list)
# from preprocessing.dwpose.pose import save_one_video
# save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
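
The refactor above adds a `fit_canvas` mode to `preprocess_video` (used for the ReCamMaster source video) alongside the existing pixel-budget scaling. A minimal sketch, with assumed example sizes, contrasting the two scaling rules:

```python
# Minimal sketch (assumed example sizes) contrasting the two scaling modes of
# preprocess_video above: fit_canvas keeps the whole frame inside the requested
# height x width, while the default mode only matches the overall pixel budget;
# both round the resulting sides down to multiples of 16.
def frame_scale(frame_height, frame_width, height, width, fit_canvas):
    if fit_canvas:
        scale = min(height / frame_height, width / frame_width)
    else:
        scale = ((height * width) / (frame_height * frame_width)) ** 0.5
    return (int(frame_height * scale) // 16) * 16, (int(frame_width * scale) // 16) * 16

# A very wide 480x1920 frame against a 480x832 target:
h_fit, w_fit = frame_scale(480, 1920, 480, 832, fit_canvas=True)
h_bud, w_bud = frame_scale(480, 1920, 480, 832, fit_canvas=False)
assert h_fit <= 480 and w_fit <= 832          # fits entirely inside the canvas
assert h_bud * w_bud > h_fit * w_fit          # budget mode keeps more total pixels here
```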
@ -2281,6 +2297,8 @@ def generate_video(
image_refs,
video_guide,
video_mask,
camera_type,
video_source,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
@ -2439,8 +2457,11 @@ def generate_video(
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
else:
raise gr.Error("Teacache not supported for this model")
source_video = None
target_camera = None
if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
target_camera = camera_type
import random
if seed == None or seed <0:
seed = random.randint(0, 999999999)
@ -2532,7 +2553,6 @@ def generate_video(
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
yield status
gen["progress_status"] = status
gen["progress_phase"] = (" - Encoding Prompt", -1 )
@ -2582,6 +2602,8 @@ def generate_video(
input_frames = src_video,
input_ref_images= src_ref_images,
input_masks = src_mask,
source_video= source_video,
target_camera= target_camera,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
@ -2755,7 +2777,8 @@ def generate_video(
state['update_gallery'] = True
if not sliding_window:
seed += 1
yield status
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.unload_loras_from_model(trans)
@ -2858,7 +2881,7 @@ def update_status(state):
gen = get_gen_info(state)
prompt_no = gen["prompt_no"]
prompts_max = gen.get("prompts_max",0)
total_generation = gen["total_generation"]
total_generation = gen.get("total_generation", 1)
repeat_no = gen["repeat_no"]
sliding_window = gen["sliding_window"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
@ -3186,7 +3209,7 @@ def prepare_inputs_dict(target, inputs ):
if target == "state":
return inputs
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
for k in unsaved_params:
inputs.pop(k)
@ -3200,6 +3223,9 @@ def prepare_inputs_dict(target, inputs ):
inputs.pop("image_prompt_type")
if not "recam" in model_filename:
inputs.pop("camera_type")
if not "Vace" in model_filename:
unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
for k in unsaved_params:
@ -3243,6 +3269,8 @@ def save_inputs(
image_refs,
video_guide,
video_mask,
camera_type,
video_source,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
@ -3551,6 +3579,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
else:
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
with gr.Column(visible= "recam" in model_filename ) as recam_column:
camera_type = gr.Dropdown(
choices=[
("Pan Right", 1),
("Pan Left", 2),
("Tilt Up", 3),
("Tilt Down", 4),
("Zoom In", 5),
("Zoom Out", 6),
("Translate Up (with rotation)", 7),
("Translate Down (with rotation)", 8),
("Arc Left (with rotation)", 9),
("Arc Right (with rotation)", 10),
],
value=ui_defaults.get("camera_type", 1),
label="Camera Movement Type", scale = 3
)
video_source = gr.Video(label= "Video Source", value= ui_defaults.get("video_source", None),)
with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
@ -3657,7 +3705,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Resolution"
)
with gr.Row():
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
if "recam" in model_filename:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
else:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
@ -3794,7 +3846,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
with gr.Tab("Miscellaneous"):
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
RIFLEx_setting = gr.Dropdown(
choices=[
@ -4007,7 +4059,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
video_prompt_type_video_guide, video_prompt_type_image_refs, recam_column] # show_advanced presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@ -4615,7 +4668,7 @@ def create_demo():
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })