Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-16 12:13:27 +00:00)

Commit 30f2889a35 (parent 19de2b1989): Added Recammaster support

README.md: 10 lines changed
@@ -1,9 +1,4 @@
# Wan2.1 GP

<p align="center">
💜 <a href=""><b>Wan</b></a> | 🖥️ <a href="https://github.com/Wan-Video/Wan2.1">GitHub</a> | 🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a> | 🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a> | 📑 <a href="">Paper (Coming soon)</a> | 📑 <a href="https://wanxai.com">Blog</a> | 💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat Group</a> | 📖 <a href="https://discord.gg/g7efUW9jGV">Discord</a>
<br>

# WanGP

-----

<p align="center">
@@ -15,6 +10,7 @@

## 🔥 Latest News!!
* April 17 2025: 👋 Wan 2.1GP v4.1: ReCamMaster model support, view a video from a different camera angle. The source video must be at least 81 frames long and you should use at least 15 denoising steps to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
    - A new UI: tabs were replaced by a dropdown box to switch models easily
    - A new queuing system that lets you stack as many text2video, image2video, ... tasks as you want in a queue. Each task can rely on completely different generation parameters (different numbers of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor to this new feature
@@ -254,7 +250,7 @@ Each preset is a file with a ".lset" extension stored in the loras directory and

Last but not least, you can pre-activate the corresponding Loras and prefill a prompt (comments only or the full prompt) by specifying a preset when launching the gradio server:

```bash
python wgp.py.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
python wgp.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
```

You will find prebuilt Loras on https://civitai.com/, or you can build them yourself with tools such as kohya or onetrainer.
@@ -319,7 +319,6 @@ class WanI2V:
        arg_c = {
            'context': [context[0]],
            'clip_fea': clip_context,
            'seq_len': max_seq_len,
            'y': [y],
            'freqs' : freqs,
            'pipeline' : self,
@@ -329,7 +328,6 @@ class WanI2V:
        arg_null = {
            'context': context_null,
            'clip_fea': clip_context,
            'seq_len': max_seq_len,
            'y': [y],
            'freqs' : freqs,
            'pipeline' : self,
@@ -340,7 +338,6 @@ class WanI2V:
            'context': [context[0]],
            'context2': context_null,
            'clip_fea': clip_context,
            'seq_len': max_seq_len,
            'y': [y],
            'freqs' : freqs,
            'pipeline' : self,
@@ -152,91 +152,52 @@ __all__ = [


def pay_attention(
    qkv_list,
    # q,
    # k,
    # v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
    force_attention= None,
    cross_attn= False
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply a causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
    """
|
||||
|
||||
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
|
||||
q,k,v = qkv_list
|
||||
qkv_list.clear()
|
||||
|
||||
half_dtypes = (torch.float16, torch.bfloat16)
|
||||
assert dtype in half_dtypes
|
||||
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
||||
|
||||
# params
|
||||
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
||||
assert b==1
|
||||
q = q.squeeze(0)
|
||||
k = k.squeeze(0)
|
||||
v = v.squeeze(0)
|
||||
|
||||
def half(x):
|
||||
return x if x.dtype in half_dtypes else x.to(dtype)
|
||||
|
||||
# preprocess query
|
||||
if q_lens is None:
|
||||
q = half(q.flatten(0, 1))
|
||||
q_lens = torch.tensor(
|
||||
[lq] * b, dtype=torch.int32).to(
|
||||
device=q.device, non_blocking=True)
|
||||
else:
|
||||
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
|
||||
|
||||
# preprocess key, value
|
||||
if k_lens is None:
|
||||
k = half(k.flatten(0, 1))
|
||||
v = half(v.flatten(0, 1))
|
||||
k_lens = torch.tensor(
|
||||
[lk] * b, dtype=torch.int32).to(
|
||||
device=k.device, non_blocking=True)
|
||||
else:
|
||||
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
|
||||
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
|
||||
|
||||
q = q.to(v.dtype)
|
||||
k = k.to(v.dtype)
|
||||
|
||||
if q_scale is not None:
|
||||
q = q * q_scale
|
||||
# if q_scale is not None:
|
||||
# q = q * q_scale
|
||||
|
||||
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
||||
warnings.warn(
|
||||
'Flash attention 3 is not available, use flash attention 2 instead.'
|
||||
)
|
||||
|
||||
if attn=="sage" or attn=="flash":
|
||||
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
|
||||
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
|
||||
|
||||
# apply attention
|
||||
if attn=="sage":
|
||||
x = sageattn_varlen_wrapper(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_q= cu_seqlens_q,
|
||||
cu_seqlens_kv= cu_seqlens_k,
|
||||
max_seqlen_q=lq,
|
||||
max_seqlen_kv=lk,
|
||||
).unflatten(0, (b, lq))
|
||||
@ -314,10 +275,8 @@ def pay_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_q= cu_seqlens_q,
|
||||
cu_seqlens_kv= cu_seqlens_k,
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
max_seqlen_q=lq,
|
||||
@ -330,10 +289,8 @@ def pay_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
||||
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
||||
cu_seqlens_q= [0, lq],
|
||||
cu_seqlens_kv=[0, lk],
|
||||
max_seqlen_q=lq,
|
||||
max_seqlen_k=lk,
|
||||
dropout_p=dropout_p,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
@ -167,11 +167,10 @@ class WanSelfAttention(nn.Module):
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, xlist, seq_lens, grid_sizes, freqs):
|
||||
def forward(self, xlist, grid_sizes, freqs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
seq_lens(Tensor): Shape [B]
|
||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
@ -196,10 +195,6 @@ class WanSelfAttention(nn.Module):
|
||||
del q,k,v
|
||||
x = pay_attention(
|
||||
qkv_list,
|
||||
# q=q,
|
||||
# k=k,
|
||||
# v=v,
|
||||
# k_lens=seq_lens,
|
||||
window_size=self.window_size)
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
@ -209,12 +204,11 @@ class WanSelfAttention(nn.Module):
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def forward(self, xlist, context, context_lens):
|
||||
def forward(self, xlist, context):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
context(Tensor): Shape [B, L2, C]
|
||||
context_lens(Tensor): Shape [B]
|
||||
"""
|
||||
x = xlist[0]
|
||||
xlist.clear()
|
||||
@ -233,7 +227,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
||||
# compute attention
|
||||
qvl_list=[q, k, v]
|
||||
del q, k, v
|
||||
x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)
|
||||
x = pay_attention(qvl_list, cross_attn= True)
|
||||
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
@ -256,12 +250,11 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, xlist, context, context_lens):
|
||||
def forward(self, xlist, context):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
context(Tensor): Shape [B, L2, C]
|
||||
context_lens(Tensor): Shape [B]
|
||||
"""
|
||||
|
||||
##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep!
|
||||
@ -287,7 +280,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
|
||||
qkv_list = [q, k, v]
|
||||
del k,v
|
||||
x = pay_attention(qkv_list, k_lens=context_lens)
|
||||
x = pay_attention(qkv_list)
|
||||
|
||||
k_img = self.k_img(context_img)
|
||||
self.norm_k_img(k_img)
|
||||
@ -295,7 +288,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
v_img = self.v_img(context_img).view(b, -1, n, d)
|
||||
qkv_list = [q, k_img, v_img]
|
||||
del q, k_img, v_img
|
||||
img_x = pay_attention(qkv_list, k_lens=None)
|
||||
img_x = pay_attention(qkv_list)
|
||||
# compute attention
|
||||
|
||||
|
||||
@ -362,30 +355,26 @@ class WanAttentionBlock(nn.Module):
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
context,
|
||||
context_lens,
|
||||
hints= None,
|
||||
context_scale=1.0,
|
||||
cam_emb= None
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
e(Tensor): Shape [B, 6, C]
|
||||
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
hint = None
|
||||
if self.block_id is not None and hints is not None:
|
||||
kwargs = {
|
||||
"seq_lens" : seq_lens,
|
||||
"grid_sizes" : grid_sizes,
|
||||
"freqs" :freqs,
|
||||
"context" : context,
|
||||
"context_lens" : context_lens,
|
||||
"e" : e,
|
||||
}
|
||||
if self.block_id == 0:
|
||||
@@ -394,20 +383,31 @@
                hint = self.vace(hints, None, **kwargs)

        e = (self.modulation + e).chunk(6, dim=1)


        # self-attention
        x_mod = self.norm1(x)
        x_mod *= 1 + e[1]
        x_mod += e[0]
        if cam_emb != None:
            cam_emb = self.cam_encoder(cam_emb)
            cam_emb = cam_emb.repeat(1, 2, 1)
            cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
            cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
            x_mod += cam_emb

        xlist = [x_mod]
        del x_mod
        y = self.self_attn( xlist, seq_lens, grid_sizes,freqs)
        y = self.self_attn( xlist, grid_sizes, freqs)
        if cam_emb != None:
            y = self.projector(y)
            # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))

        x.addcmul_(y, e[2])
        del y
        y = self.norm3(x)
        ylist= [y]
        del y
        x += self.cross_attn(ylist, context, context_lens)
        x += self.cross_attn(ylist, context)
        y = self.norm2(x)

        y *= 1 + e[4]
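The camera-embedding broadcast a few lines above is easier to follow with concrete shapes. Below is a standalone sketch (not part of the commit) using tiny illustrative sizes; in the real model D is the transformer hidden size, F the latent frame count of the target video, and H x W the latent grid, with the factor-2 repeat covering the concatenated target and source latent frames.

```python
import torch
import torch.nn as nn
from einops import rearrange

B, F, H, W, D = 1, 3, 4, 5, 8            # tiny illustrative sizes only
cam_encoder = nn.Linear(12, D)           # mirrors the block.cam_encoder added by this commit

cam_emb = torch.randn(B, F, 12)          # one 12-dim camera pose per latent frame (assumed layout)
cam_emb = cam_encoder(cam_emb)                                       # [B, F, D]
cam_emb = cam_emb.repeat(1, 2, 1)                                    # [B, 2F, D]: target + source frames
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, H, W, 1)    # [B, 2F, H, W, D]
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')             # [B, 2F*H*W, D]
print(cam_emb.shape)                     # one embedding per token, matching x_mod
```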
@ -552,6 +552,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
recammaster = False
|
||||
):
|
||||
r"""
|
||||
Initialize the diffusion model backbone.
|
||||
@@ -666,6 +667,15 @@ class WanModel(ModelMixin, ConfigMixin):
            self.vace_patch_embedding = nn.Conv3d(
                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
            )
        if recammaster :
            dim=self.blocks[0].self_attn.q.weight.shape[0]
            for block in self.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

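A quick standalone check (not part of the commit) of why the initialization above is a no-op before fine-tuning: the zero-initialized cam_encoder contributes nothing to x_mod, and the identity-initialized projector returns the self-attention output unchanged. The hidden size used below is only an illustration.

```python
import torch
import torch.nn as nn

dim = 64                                  # illustrative hidden size
cam_encoder = nn.Linear(12, dim)
projector = nn.Linear(dim, dim)
cam_encoder.weight.data.zero_()
cam_encoder.bias.data.zero_()
projector.weight = nn.Parameter(torch.eye(dim))
projector.bias = nn.Parameter(torch.zeros(dim))

cam_emb = torch.randn(1, 4, 12)           # camera poses
y = torch.randn(1, 4, dim)                # stand-in for a self-attention output

assert torch.all(cam_encoder(cam_emb) == 0)   # adds nothing to x_mod at init
assert torch.allclose(projector(y), y)        # passes y through unchanged at init
```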
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
||||
@ -716,7 +726,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context = None,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
@ -729,28 +738,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
max_steps = 0,
|
||||
slg_layers=None,
|
||||
callback = None,
|
||||
cam_emb: torch.Tensor = None,
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Tensor):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Tensor]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
clip_fea (Tensor, *optional*):
|
||||
CLIP image features for image-to-video mode
|
||||
y (List[Tensor], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
if self.model_type == 'i2v':
|
||||
assert clip_fea is not None and y is not None
|
||||
# params
|
||||
@ -775,15 +765,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
|
||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||
assert seq_lens.max() <= seq_len
|
||||
if len(x)==1 and seq_len == x[0].size(1):
|
||||
x = x[0]
|
||||
else:
|
||||
x = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in x
|
||||
])
|
||||
x = x[0]
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
@ -791,7 +773,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context = self.text_embedding(
|
||||
torch.stack([
|
||||
torch.cat(
|
||||
@ -825,10 +806,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
# arguments
|
||||
|
||||
kwargs = dict(
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=freqs,
|
||||
context_lens=context_lens,
|
||||
cam_emb = cam_emb
|
||||
)
|
||||
|
||||
if vace_context == None:
|
||||
@ -837,13 +817,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||
c = c[0]
|
||||
else:
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
c = c[0]
|
||||
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
hints_list = [ [c] for _ in range(len(x_list)) ]
|
||||
|
||||
@ -10,6 +10,7 @@ from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from mmgp import offload
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
@ -106,6 +107,9 @@ class WanT2V:
|
||||
from mmgp import offload
|
||||
|
||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
|
||||
# offload.load_model_data(self.model, "recam.ckpt")
|
||||
# self.model.cpu()
|
||||
# offload.save_model(self.model, "recam.safetensors")
|
||||
if self.dtype == torch.float16 and not "fp16" in model_filename:
|
||||
self.model.to(self.dtype)
|
||||
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
|
||||
@ -278,7 +282,9 @@ class WanT2V:
|
||||
input_prompt,
|
||||
input_frames= None,
|
||||
input_masks = None,
|
||||
input_ref_images = None,
|
||||
input_ref_images = None,
|
||||
source_video=None,
|
||||
target_camera=None,
|
||||
context_scale=1.0,
|
||||
size=(1280, 720),
|
||||
frame_num=81,
|
||||
@ -340,18 +346,19 @@ class WanT2V:
|
||||
seed_g = torch.Generator(device=self.device)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
if not self.t5_cpu:
|
||||
# self.text_encoder.model.to(self.device)
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if offload_model:
|
||||
self.text_encoder.model.cpu()
|
||||
else:
|
||||
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||
context = [t.to(self.device) for t in context]
|
||||
context_null = [t.to(self.device) for t in context_null]
|
||||
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if target_camera != None:
|
||||
size = (source_video.shape[2], source_video.shape[1])
|
||||
source_video = source_video.to(dtype=self.dtype , device=self.device)
|
||||
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
||||
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
|
||||
del source_video
|
||||
# Process target camera (recammaster)
|
||||
from wan.utils.cammmaster_tools import get_camera_embedding
|
||||
cam_emb = get_camera_embedding(target_camera)
|
||||
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
if input_frames != None:
|
||||
# vace context encode
|
||||
input_frames = [u.to(self.device) for u in input_frames]
|
||||
@ -377,22 +384,7 @@ class WanT2V:
|
||||
context = [u.to(self.dtype) for u in context]
|
||||
context_null = [u.to(self.dtype) for u in context_null]
|
||||
|
||||
noise = [
|
||||
torch.randn(
|
||||
target_shape[0],
|
||||
target_shape[1],
|
||||
target_shape[2],
|
||||
target_shape[3],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
generator=seed_g)
|
||||
]
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
|
||||
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
||||
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
|
||||
|
||||
# evaluation mode
|
||||
|
||||
@@ -419,11 +411,24 @@ class WanT2V:

        # sample videos
        latents = noise
        del noise
        batch_size =len(latents)
        freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
        arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        if target_camera != None:
            shape = list(latents[0].shape[1:])
            shape[0] *= 2
            freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
        else:
            freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
        arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
        arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
        arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}

        if target_camera != None:
            recam_dict = {'cam_emb': cam_emb}
            arg_c.update(recam_dict)
            arg_null.update(recam_dict)
            arg_both.update(recam_dict)

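The shape[0] *= 2 above pairs with the denoising loop further down, where the source-video latents are concatenated to the noise latents along the frame axis, so the rotary position grid has to cover twice as many latent frames. A small illustration follows; the tensor sizes are assumptions, not values taken from the commit.

```python
import torch

C, F, H, W = 16, 21, 60, 104                 # assumed latent shape for one sample
latents = torch.randn(C, F, H, W)            # noise latents for the target video
source_latents = torch.randn(C, F, H, W)     # VAE-encoded source video, same shape

model_input = torch.cat([latents, source_latents], dim=1)   # same concat as in the denoising loop
shape = list(latents.shape[1:])              # (F, H, W)
shape[0] *= 2                                # rotary embeddings must span 2*F frames
assert model_input.shape[1] == shape[0]
```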
if input_frames != None:
|
||||
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
|
||||
arg_c.update(vace_dict)
|
||||
@ -435,7 +440,10 @@ class WanT2V:
|
||||
if callback != None:
|
||||
callback(-1, True)
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
if target_camera != None:
|
||||
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
|
||||
else:
|
||||
latent_model_input = latents
|
||||
slg_layers_local = None
|
||||
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
|
||||
slg_layers_local = slg_layers
|
||||
@ -443,7 +451,6 @@ class WanT2V:
|
||||
offload.set_step_no_for_lora(self.model, i)
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
# self.model.to(self.device)
|
||||
if joint_pass:
|
||||
noise_pred_cond, noise_pred_uncond = self.model(
|
||||
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
|
||||
@ -459,7 +466,7 @@ class WanT2V:
|
||||
if self._interrupt:
|
||||
return None
|
||||
|
||||
del latent_model_input
|
||||
# del latent_model_input
|
||||
|
||||
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
||||
noise_pred_text = noise_pred_cond
|
||||
@ -478,7 +485,7 @@ class WanT2V:
|
||||
del noise_pred_uncond
|
||||
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred.unsqueeze(0),
|
||||
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
||||
t,
|
||||
latents[0].unsqueeze(0),
|
||||
return_dict=False,
|
||||
@ -490,24 +497,14 @@ class WanT2V:
|
||||
callback(i, False)
|
||||
|
||||
x0 = latents
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
if self.rank == 0:
|
||||
|
||||
if input_frames == None:
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
else:
|
||||
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
||||
if input_frames == None:
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
else:
|
||||
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
||||
|
||||
|
||||
del noise, latents
|
||||
del latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return videos[0] if self.rank == 0 else None
|
||||
|
||||
|
||||
@@ -69,25 +69,21 @@ def remove_background(img, session=None):


def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
    if rm_background:
        session = new_session()

    output_list =[]
    for img in img_list:
        width, height = img.size
        white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
        scale = min(canvas_height / height, canvas_width / width)
        new_height = int(height * scale)
        new_width = int(width * scale)
        scale = (budget_height * budget_width / (height * width))**(1/2)
        new_height = int( round(height * scale / 16) * 16)
        new_width = int( round(width * scale / 16) * 16)

        resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
        if rm_background:
            resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        top = (canvas_height - new_height) // 2
        left = (canvas_width - new_width) // 2
        white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
        img = Image.fromarray(white_canvas)
        output_list.append(img)
        output_list.append(resized_image)
    return output_list

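For reference, a quick check (not part of the commit) of what the new pixel-budget resize above does: instead of fitting the image inside a fixed canvas and padding it on a white background, it keeps the aspect ratio, targets roughly budget_width * budget_height pixels, and snaps both sides to multiples of 16. The 832x480 budget below is just an example value.

```python
def budget_resize_dims(width, height, budget_width=832, budget_height=480):
    # same arithmetic as the rewritten resize_and_remove_background above
    scale = (budget_height * budget_width / (height * width)) ** 0.5
    new_height = int(round(height * scale / 16) * 16)
    new_width = int(round(width * scale / 16) * 16)
    return new_width, new_height

print(budget_resize_dims(1920, 1080))   # (848, 480): ~400k pixels, both sides divisible by 16
```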
wgp.py: 183 lines changed
@@ -174,7 +174,25 @@ def process_prompt_and_add_tasks(state, model_choice):

    sliding_window_repeat = inputs["sliding_window_repeat"]
    sliding_window = sliding_window_repeat > 0
    if "Vace" in model_filename:

    if "recam" in model_filename:
        video_source = inputs["video_source"]
        if video_source == None:
            gr.Info("You must provide a Source Video")
            return
        frames = get_resampled_video(video_source, 0, 81)
        if len(frames)<81:
            gr.Info("The ReCamMaster source video should be at least 81 frames long once it has been resampled to 16 fps")
            return
        for single_prompt in prompts:
            extra_inputs = {
                "prompt" : single_prompt,
                "video_source" : video_source,
            }
            inputs.update(extra_inputs)
            add_video_task(**inputs)

elif "Vace" in model_filename:
|
||||
video_prompt_type = inputs["video_prompt_type"]
|
||||
image_refs = inputs["image_refs"]
|
||||
video_guide = inputs["video_guide"]
|
||||
@ -334,14 +352,8 @@ def process_prompt_and_add_tasks(state, model_choice):
|
||||
queue= gen.get("queue", [])
|
||||
return update_queue_data(queue)
|
||||
|
||||
def add_video_task(**inputs):
|
||||
global task_id
|
||||
state = inputs["state"]
|
||||
gen = get_gen_info(state)
|
||||
queue = gen["queue"]
|
||||
task_id += 1
|
||||
current_task_id = task_id
|
||||
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
|
||||
def get_preview_images(inputs):
|
||||
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"]
|
||||
start_image_data = None
|
||||
end_image_data = None
|
||||
for name in inputs_to_query:
|
||||
@ -353,6 +365,17 @@ def add_video_task(**inputs):
|
||||
else:
|
||||
end_image_data = image
|
||||
break
|
||||
return start_image_data, end_image_data
|
||||
|
||||
def add_video_task(**inputs):
|
||||
global task_id
|
||||
state = inputs["state"]
|
||||
gen = get_gen_info(state)
|
||||
queue = gen["queue"]
|
||||
task_id += 1
|
||||
current_task_id = task_id
|
||||
|
||||
start_image_data, end_image_data = get_preview_images(inputs)
|
||||
|
||||
queue.append({
|
||||
"id": current_task_id,
|
||||
@ -434,7 +457,7 @@ def save_queue_action(state):
|
||||
task_id_s = task.get('id', f"task_{task_index}")
|
||||
|
||||
image_keys = ["image_start", "image_end", "image_refs"]
|
||||
video_keys = ["video_guide", "video_mask"]
|
||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
||||
|
||||
for key in image_keys:
|
||||
images_pil = params_copy.get(key)
|
||||
@ -595,7 +618,7 @@ def load_queue_action(filepath, state):
|
||||
params['state'] = state
|
||||
|
||||
image_keys = ["image_start", "image_end", "image_refs"]
|
||||
video_keys = ["video_guide", "video_mask"]
|
||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
||||
|
||||
loaded_pil_images = {}
|
||||
loaded_video_paths = {}
|
||||
@ -652,20 +675,10 @@ def load_queue_action(filepath, state):
|
||||
print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}")
|
||||
params.pop(key, None)
|
||||
|
||||
primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params)
|
||||
|
||||
primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs")
|
||||
secondary_preview_pil_list = loaded_pil_images.get("image_end")
|
||||
|
||||
primary_preview_pil = None
|
||||
if primary_preview_pil_list:
|
||||
primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list
|
||||
|
||||
secondary_preview_pil = None
|
||||
if secondary_preview_pil_list:
|
||||
secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list
|
||||
|
||||
start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
|
||||
end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
|
||||
start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if primary_preview_pil_list[0] else None
|
||||
end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list[0] else None
|
||||
|
||||
top_level_start_image = params.get("image_start") or params.get("image_refs")
|
||||
top_level_end_image = params.get("image_end")
|
||||
@ -818,7 +831,7 @@ def autosave_queue():
|
||||
task_id_s = task.get('id', f"task_{task_index}")
|
||||
|
||||
image_keys = ["image_start", "image_end", "image_refs"]
|
||||
video_keys = ["video_guide", "video_mask"]
|
||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
||||
|
||||
for key in image_keys:
|
||||
images_pil = params_copy.get(key)
|
||||
@ -1352,13 +1365,13 @@ quantizeTransformer = args.quantize_transformer
|
||||
check_loras = args.check_loras ==1
|
||||
advanced = args.advanced
|
||||
|
||||
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
|
||||
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"]
|
||||
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
|
||||
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
||||
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
||||
server_config_filename = "wgp_config.json"
|
||||
if not os.path.isdir("settings"):
|
||||
os.mkdir("settings")
|
||||
os.mkdir("settings")
|
||||
if os.path.isfile("t2v_settings.json"):
|
||||
for f in glob.glob(os.path.join(".", "*_settings.json*")):
|
||||
target_file = os.path.join("settings", Path(f).parts[-1] )
|
||||
@@ -1391,30 +1404,16 @@ else:
    server_config = json.loads(text)


model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp"]
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
                    "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B" }
                    "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B" }


def get_model_type(model_filename):
    if "text2video" in model_filename and "14B" in model_filename:
        return "t2v"
    elif "text2video" in model_filename and "1.3B" in model_filename:
        return "t2v_1.3B"
    elif "Fun_InP" in model_filename and "1.3B" in model_filename:
        return "fun_inp_1.3B"
    elif "Fun_InP" in model_filename and "14B" in model_filename:
        return "fun_inp"
    elif "image2video_480p" in model_filename :
        return "i2v"
    elif "image2video_720p" in model_filename :
        return "i2v_720p"
    elif "Vace" in model_filename and "1.3B" in model_filename:
        return "vace_1.3B"
    elif "Vace" in model_filename and "14B" in model_filename:
        return "vace"
    else:
        raise Exception("Unknown model:" + model_filename)
    for model_type, signature in model_signatures.items():
        if signature in model_filename:
            return model_type
    raise Exception("Unknown model:" + model_filename)

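Illustrative calls (not in the commit) showing how the signature-based lookup above resolves the new ReCamMaster checkpoint alongside an existing one:

```python
get_model_type("ckpts/wan2.1_recammaster_1.3B_bf16.safetensors")    # -> "recam_1.3B"
get_model_type("ckpts/wan2.1_text2video_1.3B_bf16.safetensors")     # -> "t2v_1.3B"
```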
def test_class_i2v(model_filename):
|
||||
return "image2video" in model_filename or "Fun_InP" in model_filename
|
||||
@ -1862,6 +1861,9 @@ def get_model_name(model_filename):
|
||||
elif "image" in model_filename:
|
||||
model_name = "Wan2.1 image2video"
|
||||
model_name += " 720p" if "720p" in model_filename else " 480p"
|
||||
elif "recam" in model_filename:
|
||||
model_name = "ReCamMaster"
|
||||
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
||||
else:
|
||||
model_name = "Wan2.1 text2video"
|
||||
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
||||
@ -2145,9 +2147,7 @@ def convert_image(image):
|
||||
|
||||
return cast(Image, ImageOps.exif_transpose(image))
|
||||
|
||||
|
||||
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0):
|
||||
|
||||
def get_resampled_video(video_in, start_frame, max_frames):
|
||||
from wan.utils.utils import resample
|
||||
|
||||
import decord
|
||||
@ -2158,15 +2158,26 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
|
||||
|
||||
frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
|
||||
frames_list = reader.get_batch(frame_nos)
|
||||
return frames_list
|
||||
|
||||
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False):
|
||||
|
||||
frames_list = get_resampled_video(video_in, start_frame, max_frames)
|
||||
|
||||
if len(frames_list) == 0:
|
||||
return None
|
||||
frame_height, frame_width, _ = frames_list[0].shape
|
||||
|
||||
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
||||
# scale = min(height / frame_height, width / frame_width)
|
||||
if fit_canvas :
|
||||
scale = min(height / frame_height, width / frame_width)
|
||||
else:
|
||||
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
||||
|
||||
new_height = (int(frame_height * scale) // 16) * 16
|
||||
new_width = (int(frame_width * scale) // 16) * 16
|
||||
# if fit_canvas :
|
||||
# new_height = height
|
||||
# new_width = width
|
||||
|
||||
processed_frames_list = []
|
||||
for frame in frames_list:
|
||||
@ -2188,12 +2199,17 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
|
||||
"PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
|
||||
}
|
||||
anno_ins = DepthVideoAnnotator(cfg_dict)
|
||||
else:
|
||||
elif process_type=="gray":
|
||||
from preprocessing.gray import GrayVideoAnnotator
|
||||
cfg_dict = {}
|
||||
anno_ins = GrayVideoAnnotator(cfg_dict)
|
||||
|
||||
np_frames = anno_ins.forward(processed_frames_list)
|
||||
else:
|
||||
anno_ins = None
|
||||
|
||||
if anno_ins == None:
|
||||
np_frames = [np.array(frame) for frame in processed_frames_list]
|
||||
else:
|
||||
np_frames = anno_ins.forward(processed_frames_list)
|
||||
|
||||
# from preprocessing.dwpose.pose import save_one_video
|
||||
# save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
|
||||
@ -2281,6 +2297,8 @@ def generate_video(
|
||||
image_refs,
|
||||
video_guide,
|
||||
video_mask,
|
||||
camera_type,
|
||||
video_source,
|
||||
keep_frames,
|
||||
sliding_window_repeat,
|
||||
sliding_window_overlap,
|
||||
@ -2439,8 +2457,11 @@ def generate_video(
|
||||
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
|
||||
else:
|
||||
raise gr.Error("Teacache not supported for this model")
|
||||
|
||||
|
||||
source_video = None
|
||||
target_camera = None
|
||||
if "recam" in model_filename:
|
||||
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
|
||||
target_camera = camera_type
|
||||
import random
|
||||
if seed == None or seed <0:
|
||||
seed = random.randint(0, 999999999)
|
||||
@ -2532,7 +2553,6 @@ def generate_video(
|
||||
prompts_max = gen["prompts_max"]
|
||||
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
|
||||
|
||||
yield status
|
||||
|
||||
gen["progress_status"] = status
|
||||
gen["progress_phase"] = (" - Encoding Prompt", -1 )
|
||||
@ -2582,6 +2602,8 @@ def generate_video(
|
||||
input_frames = src_video,
|
||||
input_ref_images= src_ref_images,
|
||||
input_masks = src_mask,
|
||||
source_video= source_video,
|
||||
target_camera= target_camera,
|
||||
frame_num=(video_length // 4)* 4 + 1,
|
||||
size=(width, height),
|
||||
shift=flow_shift,
|
||||
@ -2755,7 +2777,8 @@ def generate_video(
|
||||
state['update_gallery'] = True
|
||||
if not sliding_window:
|
||||
seed += 1
|
||||
|
||||
yield status
|
||||
|
||||
if temp_filename!= None and os.path.isfile(temp_filename):
|
||||
os.remove(temp_filename)
|
||||
offload.unload_loras_from_model(trans)
|
||||
@ -2858,7 +2881,7 @@ def update_status(state):
|
||||
gen = get_gen_info(state)
|
||||
prompt_no = gen["prompt_no"]
|
||||
prompts_max = gen.get("prompts_max",0)
|
||||
total_generation = gen["total_generation"]
|
||||
total_generation = gen.get("total_generation", 1)
|
||||
repeat_no = gen["repeat_no"]
|
||||
sliding_window = gen["sliding_window"]
|
||||
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
|
||||
@ -3186,7 +3209,7 @@ def prepare_inputs_dict(target, inputs ):
|
||||
|
||||
if target == "state":
|
||||
return inputs
|
||||
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
|
||||
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
|
||||
for k in unsaved_params:
|
||||
inputs.pop(k)
|
||||
|
||||
@ -3200,6 +3223,9 @@ def prepare_inputs_dict(target, inputs ):
|
||||
inputs.pop("image_prompt_type")
|
||||
|
||||
|
||||
if not "recam" in model_filename:
|
||||
inputs.pop("camera_type")
|
||||
|
||||
if not "Vace" in model_filename:
|
||||
unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
|
||||
for k in unsaved_params:
|
||||
@ -3243,6 +3269,8 @@ def save_inputs(
|
||||
image_refs,
|
||||
video_guide,
|
||||
video_mask,
|
||||
camera_type,
|
||||
video_source,
|
||||
keep_frames,
|
||||
sliding_window_repeat,
|
||||
sliding_window_overlap,
|
||||
@ -3551,6 +3579,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
else:
|
||||
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
|
||||
|
||||
with gr.Column(visible= "recam" in model_filename ) as recam_column:
|
||||
camera_type = gr.Dropdown(
|
||||
choices=[
|
||||
("Pan Right", 1),
|
||||
("Pan Left", 2),
|
||||
("Tilt Up", 3),
|
||||
("Tilt Down", 4),
|
||||
("Zoom In", 5),
|
||||
("Zoom Out", 6),
|
||||
("Translate Up (with rotation)", 7),
|
||||
("Translate Down (with rotation)", 8),
|
||||
("Arc Left (with rotation)", 9),
|
||||
("Arc Right (with rotation)", 10),
|
||||
],
|
||||
value=ui_defaults.get("camera_type", 1),
|
||||
label="Camera Movement Type", scale = 3
|
||||
)
|
||||
video_source = gr.Video(label= "Video Source", value= ui_defaults.get("video_source", None),)
|
||||
|
||||
|
||||
with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
|
||||
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
|
||||
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
|
||||
@ -3657,7 +3705,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
label="Resolution"
|
||||
)
|
||||
with gr.Row():
|
||||
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
|
||||
if "recam" in model_filename:
|
||||
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
|
||||
else:
|
||||
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
|
||||
|
||||
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
|
||||
|
||||
|
||||
@ -3794,7 +3846,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||
sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
|
||||
|
||||
with gr.Tab("Miscellaneous"):
|
||||
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
|
||||
gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
|
||||
RIFLEx_setting = gr.Dropdown(
|
||||
choices=[
|
||||
@ -4007,7 +4059,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
)
|
||||
|
||||
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
|
||||
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column,
|
||||
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
|
||||
video_prompt_type_video_guide, video_prompt_type_image_refs, recam_column] # show_advanced presets_column,
|
||||
if update_form:
|
||||
locals_dict = locals()
|
||||
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
|
||||
@ -4615,7 +4668,7 @@ def create_demo():
|
||||
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
|
||||
|
||||
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
|
||||
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||
global model_list
|
||||
|
||||
tab_state = gr.State({ "tab_no":0 })
|
||||
|
||||