Added Recammaster support

DeepBeepMeep 2025-04-17 12:59:54 +02:00
parent 19de2b1989
commit 30f2889a35
7 changed files with 228 additions and 258 deletions

View File

@@ -1,9 +1,4 @@
-# Wan2.1 GP
+# WanGP
-<p align="center">
-    💜 <a href=""><b>Wan</b></a> &nbsp&nbsp &nbsp&nbsp 🖥️ <a href="https://github.com/Wan-Video/Wan2.1">GitHub</a> &nbsp&nbsp | &nbsp&nbsp🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="">Paper (Coming soon)</a> &nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://wanxai.com">Blog</a> &nbsp&nbsp | &nbsp&nbsp💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat Group</a>&nbsp&nbsp | &nbsp&nbsp 📖 <a href="https://discord.gg/g7efUW9jGV">Discord</a>&nbsp&nbsp
-<br>
 -----
 <p align="center">
@@ -15,6 +10,7 @@
 ## 🔥 Latest News!!
+* April 17 2025: 👋 Wan 2.1GP v4.1: ReCamMaster model support, view a video from a different angle. The source video must be at least 81 frames long and you should use at least 15 denoising steps to get good results.
 * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
     - A new UI, tabs were replaced by a Dropdown box to easily switch models
     - A new queuing system that lets you stack in a queue as many text2video, imag2video tasks, ... as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor on this new feature
@@ -254,7 +250,7 @@ Each preset, is a file with ".lset" extension stored in the loras directory and
 Last but not least you can pre activate Loras corresponding and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server:
 ```bash
-python wgp.py.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
+python wgp.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
 ```
 You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.

View File

@@ -319,7 +319,6 @@ class WanI2V:
         arg_c = {
             'context': [context[0]],
             'clip_fea': clip_context,
-            'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
             'pipeline' : self,
@@ -329,7 +328,6 @@ class WanI2V:
         arg_null = {
             'context': context_null,
             'clip_fea': clip_context,
-            'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
             'pipeline' : self,
@@ -340,7 +338,6 @@ class WanI2V:
             'context': [context[0]],
             'context2': context_null,
             'clip_fea': clip_context,
-            'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
             'pipeline' : self,

View File

@@ -152,91 +152,52 @@ __all__ = [
 def pay_attention(
     qkv_list,
-    # q,
-    # k,
-    # v,
-    q_lens=None,
-    k_lens=None,
     dropout_p=0.,
     softmax_scale=None,
-    q_scale=None,
     causal=False,
     window_size=(-1, -1),
     deterministic=False,
-    dtype=torch.bfloat16,
     version=None,
     force_attention= None,
     cross_attn= False
 ):
-    """
-    q: [B, Lq, Nq, C1].
-    k: [B, Lk, Nk, C1].
-    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
-    q_lens: [B].
-    k_lens: [B].
-    dropout_p: float. Dropout probability.
-    softmax_scale: float. The scaling of QK^T before applying softmax.
-    causal: bool. Whether to apply causal attention mask.
-    window_size: (left right). If not (-1, -1), apply sliding window local attention.
-    deterministic: bool. If True, slightly slower and uses more memory.
-    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
-    """
     attn = offload.shared_state["_attention"] if force_attention== None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()
-    half_dtypes = (torch.float16, torch.bfloat16)
-    assert dtype in half_dtypes
-    assert q.device.type == 'cuda' and q.size(-1) <= 256
     # params
     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
-    assert b==1
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-    v = v.squeeze(0)
-    def half(x):
-        return x if x.dtype in half_dtypes else x.to(dtype)
-    # preprocess query
-    if q_lens is None:
-        q = half(q.flatten(0, 1))
-        q_lens = torch.tensor(
-            [lq] * b, dtype=torch.int32).to(
-                device=q.device, non_blocking=True)
-    else:
-        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
-    # preprocess key, value
-    if k_lens is None:
-        k = half(k.flatten(0, 1))
-        v = half(v.flatten(0, 1))
-        k_lens = torch.tensor(
-            [lk] * b, dtype=torch.int32).to(
-                device=k.device, non_blocking=True)
-    else:
-        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
-        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
     q = q.to(v.dtype)
     k = k.to(v.dtype)
-    if q_scale is not None:
-        q = q * q_scale
+    # if q_scale is not None:
+    #     q = q * q_scale
     if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
         warnings.warn(
             'Flash attention 3 is not available, use flash attention 2 instead.'
         )
+    if attn=="sage" or attn=="flash":
+        cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
+        cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
     # apply attention
     if attn=="sage":
         x = sageattn_varlen_wrapper(
             q=q,
             k=k,
             v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_q= cu_seqlens_q,
+            cu_seqlens_kv= cu_seqlens_k,
             max_seqlen_q=lq,
             max_seqlen_kv=lk,
         ).unflatten(0, (b, lq))
@@ -314,10 +275,8 @@ def pay_attention(
             q=q,
             k=k,
             v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_q= cu_seqlens_q,
+            cu_seqlens_kv= cu_seqlens_k,
             seqused_q=None,
             seqused_k=None,
             max_seqlen_q=lq,
@@ -330,10 +289,8 @@ def pay_attention(
             q=q,
             k=k,
             v=v,
-            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
-            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_q= [0, lq],
+            cu_seqlens_kv=[0, lk],
             max_seqlen_q=lq,
             max_seqlen_k=lk,
             dropout_p=dropout_p,
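The simplification above works because pay_attention is now always called with a single unpadded sequence, so the per-batch q_lens/k_lens bookkeeping collapses into two-element cumulative-offset tensors. A minimal standalone sketch of the two equivalent forms, assuming batch size 1 (names and values below are illustrative, not part of the repository):

```python
import torch

def make_cu_seqlens_single(lq: int, lk: int, device: str = "cpu"):
    # With one unpadded sequence, the cumulative sequence lengths that the
    # sage/flash varlen kernels expect are simply [0, lq] and [0, lk].
    cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device=device)
    cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device=device)
    return cu_seqlens_q, cu_seqlens_k

def make_cu_seqlens_general(lens: torch.Tensor) -> torch.Tensor:
    # The removed general form: a prefix sum over per-sample lengths,
    # which reduces to the tensor above when lens has a single entry.
    return torch.cat([lens.new_zeros([1]), lens]).cumsum(0, dtype=torch.int32)

print(make_cu_seqlens_single(1024, 512))
print(make_cu_seqlens_general(torch.tensor([1024])))  # tensor([0, 1024], dtype=torch.int32)
```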

View File

@@ -1,6 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import math
+from einops import rearrange
 import torch
 import torch.cuda.amp as amp
 import torch.nn as nn
@@ -167,11 +167,10 @@ class WanSelfAttention(nn.Module):
         self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
-    def forward(self, xlist, seq_lens, grid_sizes, freqs):
+    def forward(self, xlist, grid_sizes, freqs):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
-            seq_lens(Tensor): Shape [B]
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
@@ -196,10 +195,6 @@ class WanSelfAttention(nn.Module):
         del q,k,v
         x = pay_attention(
             qkv_list,
-            # q=q,
-            # k=k,
-            # v=v,
-            # k_lens=seq_lens,
             window_size=self.window_size)
         # output
         x = x.flatten(2)
@@ -209,12 +204,11 @@ class WanSelfAttention(nn.Module):
 class WanT2VCrossAttention(WanSelfAttention):
-    def forward(self, xlist, context, context_lens):
+    def forward(self, xlist, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         x = xlist[0]
         xlist.clear()
@@ -233,7 +227,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         # compute attention
         qvl_list=[q, k, v]
         del q, k, v
-        x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)
+        x = pay_attention(qvl_list, cross_attn= True)
         # output
         x = x.flatten(2)
@@ -256,12 +250,11 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
-    def forward(self, xlist, context, context_lens):
+    def forward(self, xlist, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         ##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep !
@@ -287,7 +280,7 @@ class WanI2VCrossAttention(WanSelfAttention):
         qkv_list = [q, k, v]
         del k,v
-        x = pay_attention(qkv_list, k_lens=context_lens)
+        x = pay_attention(qkv_list)
         k_img = self.k_img(context_img)
         self.norm_k_img(k_img)
@@ -295,7 +288,7 @@ class WanI2VCrossAttention(WanSelfAttention):
         v_img = self.v_img(context_img).view(b, -1, n, d)
         qkv_list = [q, k_img, v_img]
         del q, k_img, v_img
-        img_x = pay_attention(qkv_list, k_lens=None)
+        img_x = pay_attention(qkv_list)
         # compute attention
@@ -362,30 +355,26 @@ class WanAttentionBlock(nn.Module):
         self,
         x,
         e,
-        seq_lens,
         grid_sizes,
         freqs,
         context,
-        context_lens,
         hints= None,
         context_scale=1.0,
+        cam_emb= None
     ):
         r"""
         Args:
             x(Tensor): Shape [B, L, C]
             e(Tensor): Shape [B, 6, C]
-            seq_lens(Tensor): Shape [B], length of each sequence in batch
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         hint = None
         if self.block_id is not None and hints is not None:
             kwargs = {
-                "seq_lens" : seq_lens,
                 "grid_sizes" : grid_sizes,
                 "freqs" :freqs,
                 "context" : context,
-                "context_lens" : context_lens,
                 "e" : e,
             }
             if self.block_id == 0:
@@ -399,15 +388,26 @@ class WanAttentionBlock(nn.Module):
         x_mod = self.norm1(x)
         x_mod *= 1 + e[1]
         x_mod += e[0]
+        if cam_emb != None:
+            cam_emb = self.cam_encoder(cam_emb)
+            cam_emb = cam_emb.repeat(1, 2, 1)
+            cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
+            cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
+            x_mod += cam_emb
         xlist = [x_mod]
         del x_mod
-        y = self.self_attn( xlist, seq_lens, grid_sizes,freqs)
+        y = self.self_attn( xlist, grid_sizes, freqs)
+        if cam_emb != None:
+            y = self.projector(y)
+            # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
         x.addcmul_(y, e[2])
         del y
         y = self.norm3(x)
         ylist= [y]
         del y
-        x += self.cross_attn(ylist, context, context_lens)
+        x += self.cross_attn(ylist, context)
         y = self.norm2(x)
         y *= 1 + e[4]
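For readers following the hunk above: each block projects a per-frame 12-dimensional camera pose into the model width and adds it to every spatial token of that frame before self-attention; the repeat(1, 2, 1) in the commit appears to account for the source-video latents being concatenated to the target latents along the frame axis. A standalone sketch of the broadcast with illustrative sizes; dim and the latent grid are placeholders, not values taken from the checkpoint:

```python
import torch
import torch.nn as nn
from einops import rearrange

dim, F, H, W = 1536, 11, 30, 52          # illustrative model width and latent grid (F, H, W)
cam_encoder = nn.Linear(12, dim)         # 12 numbers per frame, e.g. a flattened 3x4 camera matrix

cam_pose = torch.randn(1, F, 12)         # [B, F, 12]: one pose per latent frame
cam_emb = cam_encoder(cam_pose)          # [B, F, dim]
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, H, W, 1)   # copy the pose to every spatial token
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')            # flatten to the token sequence layout

x_mod = torch.randn(1, F * H * W, dim)   # modulated hidden states entering self-attention
x_mod = x_mod + cam_emb                  # additive camera conditioning
print(x_mod.shape)                       # torch.Size([1, 17160, 1536])
```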
@@ -552,6 +552,7 @@ class WanModel(ModelMixin, ConfigMixin):
         qk_norm=True,
         cross_attn_norm=True,
         eps=1e-6,
+        recammaster = False
     ):
         r"""
         Initialize the diffusion model backbone.
@@ -666,6 +667,15 @@ class WanModel(ModelMixin, ConfigMixin):
             self.vace_patch_embedding = nn.Conv3d(
                 self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
             )
+        if recammaster :
+            dim=self.blocks[0].self_attn.q.weight.shape[0]
+            for block in self.blocks:
+                block.cam_encoder = nn.Linear(12, dim)
+                block.projector = nn.Linear(dim, dim)
+                block.cam_encoder.weight.data.zero_()
+                block.cam_encoder.bias.data.zero_()
+                block.projector.weight = nn.Parameter(torch.eye(dim))
+                block.projector.bias = nn.Parameter(torch.zeros(dim))
     def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
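A note on the initialization above: the injected layers are constructed so that a block behaves exactly like the pretrained model until the ReCamMaster weights are loaded on top, since the camera encoder outputs zeros and the projector starts as the identity. A small self-contained check of that property (dim is an illustrative value; the real one is read from the loaded transformer):

```python
import torch
import torch.nn as nn

dim = 1536                                # illustrative; the model reads it from self.blocks[0].self_attn.q
cam_encoder = nn.Linear(12, dim)
projector = nn.Linear(dim, dim)

cam_encoder.weight.data.zero_()
cam_encoder.bias.data.zero_()
projector.weight = nn.Parameter(torch.eye(dim))
projector.bias = nn.Parameter(torch.zeros(dim))

y = torch.randn(4, dim)
assert torch.allclose(projector(y), y)                              # identity pass-through
assert torch.count_nonzero(cam_encoder(torch.randn(4, 12))) == 0    # no camera signal until the checkpoint is loaded
print("recammaster layers start as a no-op")
```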
@@ -716,7 +726,6 @@ class WanModel(ModelMixin, ConfigMixin):
         x,
         t,
         context,
-        seq_len,
         vace_context = None,
         vace_context_scale=1.0,
         clip_fea=None,
@@ -729,28 +738,9 @@ class WanModel(ModelMixin, ConfigMixin):
         max_steps = 0,
         slg_layers=None,
         callback = None,
+        cam_emb: torch.Tensor = None,
     ):
-        r"""
-        Forward pass through the diffusion model
-        Args:
-            x (List[Tensor]):
-                List of input video tensors, each with shape [C_in, F, H, W]
-            t (Tensor):
-                Diffusion timesteps tensor of shape [B]
-            context (List[Tensor]):
-                List of text embeddings each with shape [L, C]
-            seq_len (`int`):
-                Maximum sequence length for positional encoding
-            clip_fea (Tensor, *optional*):
-                CLIP image features for image-to-video mode
-            y (List[Tensor], *optional*):
-                Conditional video inputs for image-to-video mode, same shape as x
-        Returns:
-            List[Tensor]:
-                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
-        """
         if self.model_type == 'i2v':
             assert clip_fea is not None and y is not None
         # params
@@ -775,15 +765,7 @@ class WanModel(ModelMixin, ConfigMixin):
         x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
-        assert seq_lens.max() <= seq_len
-        if len(x)==1 and seq_len == x[0].size(1):
-            x = x[0]
-        else:
-            x = torch.cat([
-                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                dim=1) for u in x
-            ])
+        x = x[0]
         # time embeddings
         e = self.time_embedding(
@@ -791,7 +773,6 @@ class WanModel(ModelMixin, ConfigMixin):
         e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
         # context
-        context_lens = None
         context = self.text_embedding(
             torch.stack([
                 torch.cat(
@@ -825,10 +806,9 @@ class WanModel(ModelMixin, ConfigMixin):
         # arguments
         kwargs = dict(
-            seq_lens=seq_lens,
             grid_sizes=grid_sizes,
             freqs=freqs,
-            context_lens=context_lens,
+            cam_emb = cam_emb
         )
         if vace_context == None:
@@ -837,13 +817,7 @@ class WanModel(ModelMixin, ConfigMixin):
             # embeddings
             c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
             c = [u.flatten(2).transpose(1, 2) for u in c]
-            if (len(c) == 1 and seq_len == c[0].size(1)):
-                c = c[0]
-            else:
-                c = torch.cat([
-                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                    dim=1) for u in c
-                ])
+            c = c[0]
             kwargs['context_scale'] = vace_context_scale
             hints_list = [ [c] for _ in range(len(x_list)) ]

View File

@@ -10,6 +10,7 @@ from contextlib import contextmanager
 from functools import partial
 from mmgp import offload
 import torch
+import torch.nn as nn
 import torch.cuda.amp as amp
 import torch.distributed as dist
 from tqdm import tqdm
@@ -106,6 +107,9 @@ class WanT2V:
         from mmgp import offload
         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
+        # offload.load_model_data(self.model, "recam.ckpt")
+        # self.model.cpu()
+        # offload.save_model(self.model, "recam.safetensors")
         if self.dtype == torch.float16 and not "fp16" in model_filename:
             self.model.to(self.dtype)
         # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
@@ -279,6 +283,8 @@ class WanT2V:
         input_frames= None,
         input_masks = None,
         input_ref_images = None,
+        source_video=None,
+        target_camera=None,
         context_scale=1.0,
         size=(1280, 720),
         frame_num=81,
@@ -340,17 +346,18 @@ class WanT2V:
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)
-        if not self.t5_cpu:
-            # self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
-            if offload_model:
-                self.text_encoder.model.cpu()
-        else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
-            context = [t.to(self.device) for t in context]
-            context_null = [t.to(self.device) for t in context_null]
+        context = self.text_encoder([input_prompt], self.device)
+        context_null = self.text_encoder([n_prompt], self.device)
+        if target_camera != None:
+            size = (source_video.shape[2], source_video.shape[1])
+            source_video = source_video.to(dtype=self.dtype , device=self.device)
+            source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
+            source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
+            del source_video
+            # Process target camera (recammaster)
+            from wan.utils.cammmaster_tools import get_camera_embedding
+            cam_emb = get_camera_embedding(target_camera)
+            cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
         if input_frames != None:
             # vace context encode
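To make the new ReCamMaster branch easier to follow: the resampled source clip arrives as a [F, H, W, C] frame stack with values in 0..255, is converted to the channels-first layout and [-1, 1] range the Wan VAE expects, and is encoded once so its latents can be reused at every denoising step. A minimal sketch with illustrative shapes (the real encode call goes through self.vae):

```python
import torch

frames = torch.randint(0, 256, (81, 480, 832, 3)).float()   # illustrative resampled source clip [F, H, W, C]
video = frames.permute(3, 0, 1, 2).div_(127.5).sub_(1.)     # [C, F, H, W], values in [-1, 1]
print(video.shape, float(video.min()) >= -1.0, float(video.max()) <= 1.0)
# source_latents = vae.encode([video])   # in the pipeline, the Wan VAE encodes this once up front
```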
@@ -377,22 +384,7 @@ class WanT2V:
         context = [u.to(self.dtype) for u in context]
         context_null = [u.to(self.dtype) for u in context_null]
-        noise = [
-            torch.randn(
-                target_shape[0],
-                target_shape[1],
-                target_shape[2],
-                target_shape[3],
-                dtype=torch.float32,
-                device=self.device,
-                generator=seed_g)
-        ]
-        @contextmanager
-        def noop_no_sync():
-            yield
-        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+        noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
         # evaluation mode
@@ -419,11 +411,24 @@
         # sample videos
         latents = noise
+        del noise
         batch_size =len(latents)
-        freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
-        arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
-        arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
-        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+        if target_camera != None:
+            shape = list(latents[0].shape[1:])
+            shape[0] *= 2
+            freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
+        else:
+            freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
+        arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
+        arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
+        arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
+        if target_camera != None:
+            recam_dict = {'cam_emb': cam_emb}
+            arg_c.update(recam_dict)
+            arg_null.update(recam_dict)
+            arg_both.update(recam_dict)
         if input_frames != None:
             vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
             arg_c.update(vace_dict)
@@ -435,7 +440,10 @@ class WanT2V:
         if callback != None:
             callback(-1, True)
         for i, t in enumerate(tqdm(timesteps)):
-            latent_model_input = latents
+            if target_camera != None:
+                latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
+            else:
+                latent_model_input = latents
             slg_layers_local = None
             if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
                 slg_layers_local = slg_layers
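The loop above implements the core ReCamMaster conditioning: at every step the noisy target latents are concatenated with the clean source latents along the frame axis before the transformer call, and (as the later scheduler-step change shows) only the target-frame part of the prediction is kept. A shape-level sketch with illustrative latent sizes:

```python
import torch

C, F, H, W = 16, 21, 60, 104                     # illustrative latent shape [C, F, H, W]
target_latents = torch.randn(C, F, H, W)         # latents being denoised
source_latents = torch.randn(C, F, H, W)         # fixed latents of the VAE-encoded source video

latent_model_input = torch.cat([target_latents, source_latents], dim=1)   # [C, 2F, H, W]
noise_pred = torch.randn_like(latent_model_input)                         # stand-in for the transformer output
noise_pred = noise_pred[:, :F]                    # keep only the target frames for the scheduler step
print(latent_model_input.shape, noise_pred.shape)
```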
@@ -443,7 +451,6 @@ class WanT2V:
             offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)
-            # self.model.to(self.device)
             if joint_pass:
                 noise_pred_cond, noise_pred_uncond = self.model(
                     latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
@@ -459,7 +466,7 @@ class WanT2V:
                 if self._interrupt:
                     return None
-            del latent_model_input
+            # del latent_model_input
             # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
             noise_pred_text = noise_pred_cond
@@ -478,7 +485,7 @@ class WanT2V:
             del noise_pred_uncond
             temp_x0 = sample_scheduler.step(
-                noise_pred.unsqueeze(0),
+                noise_pred[:, :target_shape[1]].unsqueeze(0),
                 t,
                 latents[0].unsqueeze(0),
                 return_dict=False,
@@ -490,24 +497,14 @@ class WanT2V:
                 callback(i, False)
         x0 = latents
-        if offload_model:
-            self.model.cpu()
-            torch.cuda.empty_cache()
-        if self.rank == 0:
         if input_frames == None:
             videos = self.vae.decode(x0, VAE_tile_size)
         else:
             videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
-        del noise, latents
+        del latents
         del sample_scheduler
-        if offload_model:
-            gc.collect()
-            torch.cuda.synchronize()
-        if dist.is_initialized():
-            dist.barrier()
         return videos[0] if self.rank == 0 else None

View File

@@ -69,25 +69,21 @@ def remove_background(img, session=None):
-def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
     if rm_background:
         session = new_session()
     output_list =[]
     for img in img_list:
         width, height = img.size
-        white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
-        scale = min(canvas_height / height, canvas_width / width)
-        new_height = int(height * scale)
-        new_width = int(width * scale)
+        scale = (budget_height * budget_width / (height * width))**(1/2)
+        new_height = int( round(height * scale / 16) * 16)
+        new_width = int( round(width * scale / 16) * 16)
         resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
         if rm_background:
             resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
-        top = (canvas_height - new_height) // 2
-        left = (canvas_width - new_width) // 2
-        white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
-        img = Image.fromarray(white_canvas)
-        output_list.append(img)
+        output_list.append(resized_image)
     return output_list
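The new logic replaces the letterbox-onto-a-white-canvas behavior with an area-budget resize: the reference image keeps its aspect ratio, its area is scaled to roughly budget_width x budget_height, and both sides are rounded to multiples of 16. A standalone sketch (budget values are illustrative defaults, not taken from the repository):

```python
from PIL import Image

def budget_resize(img: Image.Image, budget_width: int = 832, budget_height: int = 480) -> Image.Image:
    width, height = img.size
    scale = (budget_height * budget_width / (height * width)) ** 0.5   # match the pixel budget, keep aspect ratio
    new_height = int(round(height * scale / 16) * 16)                  # round to the 16-pixel latent granularity
    new_width = int(round(width * scale / 16) * 16)
    return img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

print(budget_resize(Image.new("RGB", (1920, 1080))).size)   # (848, 480)
```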

wgp.py (177 changed lines)
View File

@@ -174,7 +174,25 @@ def process_prompt_and_add_tasks(state, model_choice):
     sliding_window_repeat = inputs["sliding_window_repeat"]
     sliding_window = sliding_window_repeat > 0
-    if "Vace" in model_filename:
+    if "recam" in model_filename:
+        video_source = inputs["video_source"]
+        if video_source == None:
+            gr.Info("You must provide a Source Video")
+            return
+        frames = get_resampled_video(video_source, 0, 81)
+        if len(frames)<81:
+            gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done")
+            return
+        for single_prompt in prompts:
+            extra_inputs = {
+                "prompt" : single_prompt,
+                "video_source" : video_source,
+            }
+            inputs.update(extra_inputs)
+            add_video_task(**inputs)
+    elif "Vace" in model_filename:
         video_prompt_type = inputs["video_prompt_type"]
         image_refs = inputs["image_refs"]
         video_guide = inputs["video_guide"]
@@ -334,14 +352,8 @@ def process_prompt_and_add_tasks(state, model_choice):
     queue= gen.get("queue", [])
     return update_queue_data(queue)
-def add_video_task(**inputs):
-    global task_id
-    state = inputs["state"]
-    gen = get_gen_info(state)
-    queue = gen["queue"]
-    task_id += 1
-    current_task_id = task_id
-    inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
+def get_preview_images(inputs):
+    inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"]
     start_image_data = None
     end_image_data = None
     for name in inputs_to_query:
@@ -353,6 +365,17 @@ def add_video_task(**inputs):
         else:
             end_image_data = image
             break
+    return start_image_data, end_image_data
+def add_video_task(**inputs):
+    global task_id
+    state = inputs["state"]
+    gen = get_gen_info(state)
+    queue = gen["queue"]
+    task_id += 1
+    current_task_id = task_id
+    start_image_data, end_image_data = get_preview_images(inputs)
     queue.append({
         "id": current_task_id,
@@ -434,7 +457,7 @@ def save_queue_action(state):
         task_id_s = task.get('id', f"task_{task_index}")
         image_keys = ["image_start", "image_end", "image_refs"]
-        video_keys = ["video_guide", "video_mask"]
+        video_keys = ["video_guide", "video_mask", "video_source"]
         for key in image_keys:
             images_pil = params_copy.get(key)
@@ -595,7 +618,7 @@ def load_queue_action(filepath, state):
             params['state'] = state
             image_keys = ["image_start", "image_end", "image_refs"]
-            video_keys = ["video_guide", "video_mask"]
+            video_keys = ["video_guide", "video_mask", "video_source"]
             loaded_pil_images = {}
             loaded_video_paths = {}
@@ -652,20 +675,10 @@ def load_queue_action(filepath, state):
                         print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}")
                         params.pop(key, None)
-            primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs")
-            secondary_preview_pil_list = loaded_pil_images.get("image_end")
-            primary_preview_pil = None
-            if primary_preview_pil_list:
-                primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list
-            secondary_preview_pil = None
-            if secondary_preview_pil_list:
-                secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list
-            start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
-            end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
+            primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params)
+            start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if primary_preview_pil_list[0] else None
+            end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list[0] else None
             top_level_start_image = params.get("image_start") or params.get("image_refs")
             top_level_end_image = params.get("image_end")
@@ -818,7 +831,7 @@ def autosave_queue():
         task_id_s = task.get('id', f"task_{task_index}")
         image_keys = ["image_start", "image_end", "image_refs"]
-        video_keys = ["video_guide", "video_mask"]
+        video_keys = ["video_guide", "video_mask", "video_source"]
         for key in image_keys:
             images_pil = params_copy.get(key)
@@ -1352,7 +1365,7 @@ quantizeTransformer = args.quantize_transformer
 check_loras = args.check_loras ==1
 advanced = args.advanced
-transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
+transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"]
 transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
 transformer_choices = transformer_choices_t2v + transformer_choices_i2v
 text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
@@ -1391,30 +1404,16 @@ else:
     server_config = json.loads(text)
-model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp"]
+model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B"]
 model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
-                    "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B" }
+                    "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B" }
 def get_model_type(model_filename):
-    if "text2video" in model_filename and "14B" in model_filename:
-        return "t2v"
-    elif "text2video" in model_filename and "1.3B" in model_filename:
-        return "t2v_1.3B"
-    elif "Fun_InP" in model_filename and "1.3B" in model_filename:
-        return "fun_inp_1.3B"
-    elif "Fun_InP" in model_filename and "14B" in model_filename:
-        return "fun_inp"
-    elif "image2video_480p" in model_filename :
-        return "i2v"
-    elif "image2video_720p" in model_filename :
-        return "i2v_720p"
-    elif "Vace" in model_filename and "1.3B" in model_filename:
-        return "vace_1.3B"
-    elif "Vace" in model_filename and "14B" in model_filename:
-        return "vace"
-    else:
-        raise Exception("Unknown model:" + model_filename)
+    for model_type, signature in model_signatures.items():
+        if signature in model_filename:
+            return model_type
+    raise Exception("Unknown model:" + model_filename)
 def test_class_i2v(model_filename):
     return "image2video" in model_filename or "Fun_InP" in model_filename
@@ -1862,6 +1861,9 @@ def get_model_name(model_filename):
     elif "image" in model_filename:
         model_name = "Wan2.1 image2video"
         model_name += " 720p" if "720p" in model_filename else " 480p"
+    elif "recam" in model_filename:
+        model_name = "ReCamMaster"
+        model_name += " 14B" if "14B" in model_filename else " 1.3B"
     else:
         model_name = "Wan2.1 text2video"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
@@ -2145,9 +2147,7 @@ def convert_image(image):
     return cast(Image, ImageOps.exif_transpose(image))
-def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0):
+def get_resampled_video(video_in, start_frame, max_frames):
     from wan.utils.utils import resample
     import decord
@@ -2158,15 +2158,26 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
     frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
     frames_list = reader.get_batch(frame_nos)
+    return frames_list
+def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False):
+    frames_list = get_resampled_video(video_in, start_frame, max_frames)
     if len(frames_list) == 0:
         return None
     frame_height, frame_width, _ = frames_list[0].shape
-    scale = ((height * width ) / (frame_height * frame_width))**(1/2)
-    # scale = min(height / frame_height, width / frame_width)
+    if fit_canvas :
+        scale = min(height / frame_height, width / frame_width)
+    else:
+        scale = ((height * width ) / (frame_height * frame_width))**(1/2)
     new_height = (int(frame_height * scale) // 16) * 16
     new_width = (int(frame_width * scale) // 16) * 16
+    # if fit_canvas :
+    #     new_height = height
+    #     new_width = width
     processed_frames_list = []
     for frame in frames_list:
@@ -2188,12 +2199,17 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
             "PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
         }
         anno_ins = DepthVideoAnnotator(cfg_dict)
-    else:
+    elif process_type=="gray":
         from preprocessing.gray import GrayVideoAnnotator
         cfg_dict = {}
         anno_ins = GrayVideoAnnotator(cfg_dict)
+    else:
+        anno_ins = None
-    np_frames = anno_ins.forward(processed_frames_list)
+    if anno_ins == None:
+        np_frames = [np.array(frame) for frame in processed_frames_list]
+    else:
+        np_frames = anno_ins.forward(processed_frames_list)
     # from preprocessing.dwpose.pose import save_one_video
     # save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
@@ -2281,6 +2297,8 @@ def generate_video(
     image_refs,
     video_guide,
     video_mask,
+    camera_type,
+    video_source,
     keep_frames,
     sliding_window_repeat,
     sliding_window_overlap,
@@ -2439,8 +2457,11 @@ def generate_video(
             trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
         else:
             raise gr.Error("Teacache not supported for this model")
+    source_video = None
+    target_camera = None
+    if "recam" in model_filename:
+        source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
+        target_camera = camera_type
     import random
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)
@@ -2532,7 +2553,6 @@ def generate_video(
             prompts_max = gen["prompts_max"]
             status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
-            yield status
             gen["progress_status"] = status
             gen["progress_phase"] = (" - Encoding Prompt", -1 )
@@ -2582,6 +2602,8 @@ def generate_video(
                     input_frames = src_video,
                     input_ref_images= src_ref_images,
                     input_masks = src_mask,
+                    source_video= source_video,
+                    target_camera= target_camera,
                     frame_num=(video_length // 4)* 4 + 1,
                     size=(width, height),
                     shift=flow_shift,
@@ -2755,6 +2777,7 @@ def generate_video(
             state['update_gallery'] = True
             if not sliding_window:
                 seed += 1
+            yield status
     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
@@ -2858,7 +2881,7 @@ def update_status(state):
     gen = get_gen_info(state)
     prompt_no = gen["prompt_no"]
     prompts_max = gen.get("prompts_max",0)
-    total_generation = gen["total_generation"]
+    total_generation = gen.get("total_generation", 1)
     repeat_no = gen["repeat_no"]
     sliding_window = gen["sliding_window"]
     status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
@@ -3186,7 +3209,7 @@ def prepare_inputs_dict(target, inputs ):
     if target == "state":
         return inputs
-    unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
+    unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
     for k in unsaved_params:
         inputs.pop(k)
@@ -3200,6 +3223,9 @@ def prepare_inputs_dict(target, inputs ):
         inputs.pop("image_prompt_type")
+    if not "recam" in model_filename:
+        inputs.pop("camera_type")
     if not "Vace" in model_filename:
         unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
         for k in unsaved_params:
@@ -3243,6 +3269,8 @@ def save_inputs(
     image_refs,
     video_guide,
     video_mask,
+    camera_type,
+    video_source,
     keep_frames,
     sliding_window_repeat,
     sliding_window_overlap,
@@ -3551,6 +3579,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 else:
                     image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+            with gr.Column(visible= "recam" in model_filename ) as recam_column:
+                camera_type = gr.Dropdown(
+                    choices=[
+                        ("Pan Right", 1),
+                        ("Pan Left", 2),
+                        ("Tilt Up", 3),
+                        ("Tilt Down", 4),
+                        ("Zoom In", 5),
+                        ("Zoom Out", 6),
+                        ("Translate Up (with rotation)", 7),
+                        ("Translate Down (with rotation)", 8),
+                        ("Arc Left (with rotation)", 9),
+                        ("Arc Right (with rotation)", 10),
+                    ],
+                    value=ui_defaults.get("camera_type", 1),
+                    label="Camera Movement Type", scale = 3
+                )
+                video_source = gr.Video(label= "Video Source", value= ui_defaults.get("video_source", None),)
             with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
                 video_prompt_type_value= ui_defaults.get("video_prompt_type","")
                 video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
@@ -3657,7 +3705,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     label="Resolution"
                 )
             with gr.Row():
-                video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
+                if "recam" in model_filename:
+                    video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
+                else:
+                    video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
                 num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
@@ -3794,7 +3846,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
                         sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
-                    with gr.Tab("Miscellaneous"):
+                    with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
                         gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
                         RIFLEx_setting = gr.Dropdown(
                             choices=[
@@ -4007,7 +4059,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             )
         extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
-                                      prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column,
+                                      prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
+                                      video_prompt_type_video_guide, video_prompt_type_image_refs, recam_column] # show_advanced presets_column,
         if update_form:
             locals_dict = locals()
             gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -4615,7 +4668,7 @@ def create_demo():
     theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
     with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
-        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
+        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
         global model_list
         tab_state = gr.State({ "tab_no":0 })