Added Preview mode and support for Sky Reels v2 Diffusion Forcing

This commit is contained in:
deepbeepmeep 2025-04-25 22:07:13 +02:00
parent a63aff0377
commit 7c447ea36c
8 changed files with 1155 additions and 832 deletions

View File

@ -10,6 +10,7 @@
## 🔥 Latest News!!
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below)
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support: official support from Wan for image2video with start and end frames, specialized for 720p.
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 denoising steps to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
@ -302,16 +303,20 @@ There is also a guide that describes the various combination of hints (https://g
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
### VACE Sliding Window
With this mode (which for the moment works only with Vace) you can merge multiple videos to form a very long video (up to 1 min). What is very nice about this feature is that the resulting video can be driven by the same control video. For instance, the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).
To turn on sliding windows, go to the Advanced Settings tab *Sliding Window* and set the iteration number to a value greater than 1. This number corresponds to the default number of windows. You can still increase the number during the generation by clicking the "One More Sample, Please !" button.
When combined with Vace, this feature can use the same control video to generate the full video that results from concatenating the different windows. For instance, the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
Each window duration is set by the *Number of frames (16 = 1s)* form field. However, the actual number of frames generated by each iteration will be less because of the *overlap frames* and *discard last frames*:
When combined with Sky Reels V2, you can extend an existing video indefinitely.
Sliding Windows are turned on by default and are triggered as soon as you try to generate a video longer than the Window Size. You can go to the Advanced Settings tab *Sliding Window* to set this Window Size. You can make the video even longer during the generation process by adding one more window to generate each time you click the "Extend the Video Sample, Please !" button.
Although the window duration is set by the *Sliding Window Size* form field, the actual number of frames generated by each iteration will be less, because of the *overlap frames* and *discard last frames*:
- *overlap frames* : the first frames of a new window are filled with the last frames of the previous window to ensure continuity between the two windows
- *discard last frames* : quite often the last frames of a window have lower quality. You decide here how many ending frames of a new window should be dropped.
Number of Generated Frames = [Number of iterations] * ([Number of frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
- *discard last frames* : quite often (Vace model only) the last frames of a window have lower quality. You can decide here how many ending frames of a new window should be dropped.
Number of Generated Frames = [Number of Windows - 1] * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
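For reference, here is a minimal Python sketch of the formula above (the helper name is hypothetical, it is not part of the app): every window after the first only contributes its non-overlapping, non-discarded frames.

```python
# Minimal sketch of the frame-count formula above (hypothetical helper,
# not part of this repository).
def total_generated_frames(num_windows: int, window_size: int,
                           overlap_frames: int = 0, discard_last_frames: int = 0) -> int:
    # Each window after the first reuses `overlap_frames` from the previous
    # window and drops `discard_last_frames`, so it adds fewer new frames.
    new_frames_per_extra_window = window_size - overlap_frames - discard_last_frames
    return (num_windows - 1) * new_frames_per_extra_window + window_size

# Example: 4 windows of 81 frames, 8 overlap frames, 4 discarded frames
# -> 3 * (81 - 8 - 4) + 81 = 288 frames, i.e. 18s at 16 frames per second.
print(total_generated_frames(4, 81, 8, 4))  # 288
```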
Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
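A small sketch of that experimental per-window prompt behaviour as described above (the helper below is hypothetical, not the app's actual code):

```python
# Hypothetical illustration of the experimental multi-line prompt behaviour:
# line i of the prompt drives window i, and the last line is reused once
# there are more windows than prompt lines.
def prompt_for_window(prompt: str, window_index: int) -> str:
    lines = [line.strip() for line in prompt.splitlines() if line.strip()]
    if not lines:
        return ""
    return lines[min(window_index, len(lines) - 1)]

prompt = "a person walks on the beach\nthe person starts to run\nthe person dives into the sea"
print(prompt_for_window(prompt, 0))  # first window -> first prompt line
print(prompt_for_window(prompt, 5))  # more windows than lines -> last line repeated
```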

View File

@ -12,7 +12,7 @@ ftfy
dashscope
imageio-ffmpeg
# flash_attn
gradio>=5.0.0
gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3

View File

@ -1,3 +1,4 @@
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .diffusion_forcing import DTT2V

View File

@ -352,7 +352,7 @@ class WanI2V:
# self.model.to(self.device)
if callback != None:
callback(-1, True)
callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@ -426,7 +426,7 @@ class WanI2V:
del timestep
if callback is not None:
callback(i, False)
callback(i, latent, False)
x0 = [latent.to(self.device, dtype=self.dtype)]

View File

@ -10,6 +10,7 @@ import numpy as np
from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel
__all__ = ['WanModel']
@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
return x
def reshape_latent(latent, latent_frames):
    if latent_frames == latent.shape[0]:
        return latent
    return latent.reshape(latent_frames, -1, latent.shape[-1] )
def identify_k( b: float, d: int, N: int):
@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, grid_sizes, freqs):
def forward(self, xlist, grid_sizes, freqs, block_mask = None):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -190,12 +195,44 @@ class WanSelfAttention(nn.Module):
del x
qklist = [q,k]
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
qkv_list = [q,k,v]
del q,k,v
if block_mask == None:
    x = pay_attention(
        qkv_list,
        window_size=self.window_size)
else:
    with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        x = (
            torch.nn.functional.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
            )
            .transpose(1, 2)
            .contiguous()
        )
# if not self._flag_ar_attention:
# q = rope_apply(q, grid_sizes, freqs)
# k = rope_apply(k, grid_sizes, freqs)
# x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
# else:
# q = rope_apply(q, grid_sizes, freqs)
# k = rope_apply(k, grid_sizes, freqs)
# q = q.to(torch.bfloat16)
# k = k.to(torch.bfloat16)
# v = v.to(torch.bfloat16)
# with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# x = (
# torch.nn.functional.scaled_dot_product_attention(
# q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
# )
# .transpose(1, 2)
# .contiguous()
# )
# output
x = x.flatten(2)
x = self.o(x)
@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
context,
hints= None,
context_scale=1.0,
cam_emb= None
cam_emb= None,
block_mask = None
):
r"""
Args:
@ -381,13 +419,14 @@ class WanAttentionBlock(nn.Module):
hint = self.vace(hints, x, **kwargs)
else:
hint = self.vace(hints, None, **kwargs)
latent_frames = e.shape[0]
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
x_mod = self.norm1(x)
x_mod = reshape_latent(x_mod , latent_frames)
x_mod *= 1 + e[1]
x_mod += e[0]
x_mod = reshape_latent(x_mod , 1)
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
@ -397,12 +436,13 @@ class WanAttentionBlock(nn.Module):
xlist = [x_mod]
del x_mod
y = self.self_attn( xlist, grid_sizes, freqs)
y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
if cam_emb != None:
y = self.projector(y)
# x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[2])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
del y
y = self.norm3(x)
ylist= [y]
@ -410,8 +450,10 @@ class WanAttentionBlock(nn.Module):
x += self.cross_attn(ylist, context)
y = self.norm2(x)
y = reshape_latent(y , latent_frames)
y *= 1 + e[4]
y += e[3]
y = reshape_latent(y , 1)
ffn = self.ffn[0]
gelu = self.ffn[1]
@ -428,7 +470,9 @@ class WanAttentionBlock(nn.Module):
del mlp_chunk
y = y.view(y_shape)
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
if hint is not None:
if context_scale == 1:
@ -500,10 +544,14 @@ class Head(nn.Module):
"""
# assert e.dtype == torch.float32
dtype = x.dtype
latent_frames = e.shape[0]
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = self.norm(x).to(dtype)
x = reshape_latent(x , latent_frames)
x *= (1 + e[1])
x += e[0]
x = reshape_latent(x , 1)
x = self.head(x)
return x
@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
recammaster = False
recammaster = False,
inject_sample_info = False,
):
r"""
Initialize the diffusion model backbone.
@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.num_frame_per_block = 1
self.flag_causal_attention = False
self.block_mask = None
self.inject_sample_info = inject_sample_info
# embeddings
self.patch_embedding = nn.Conv3d(
@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
if inject_sample_info:
    self.fps_embedding = nn.Embedding(2, dim)
    self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
@ -683,7 +740,8 @@ class WanModel(ModelMixin, ConfigMixin):
e_list = []
for t in timesteps:
t = torch.stack([t])
e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
e_list.append(time_emb)
best_threshold = 0.01
best_diff = 1000
@ -697,14 +755,11 @@ class WanModel(ModelMixin, ConfigMixin):
for i, t in enumerate(timesteps):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
if accumulated_rel_l1_distance < threshold:
skip = True
else:
accumulated_rel_l1_distance = 0
previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
signed_diff = target_nb_steps - nb_steps
@ -739,6 +794,9 @@ class WanModel(ModelMixin, ConfigMixin):
slg_layers=None,
callback = None,
cam_emb: torch.Tensor = None,
fps = None,
causal_block_size = 1,
causal_attention = False,
):
if self.model_type == 'i2v':
@ -758,20 +816,47 @@ class WanModel(ModelMixin, ConfigMixin):
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0]
if causal_attention : #causal_block_size > 0:
    frame_num = embed_sizes[0]
    height = embed_sizes[1]
    width = embed_sizes[2]
    block_num = frame_num // causal_block_size
    range_tensor = torch.arange(block_num).view(-1, 1)
    range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
    causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
    causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
    causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
    causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
    block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
    del causal_mask
offload.shared_state["embed_sizes"] = embed_sizes
offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps
x = [u.flatten(2).transpose(1, 2) for u in x]
x = x[0]
# time embeddings
if t.dim() == 2:
    b, f = t.shape
    _flag_df = True
else:
    _flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info:
    fps = torch.tensor(fps, dtype=torch.long, device=device)
    fps_emb = self.fps_embedding(fps).float()
    if _flag_df:
        e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
    else:
        e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(
torch.stack([
@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
self.teacache_skipped_steps += 1
@ -858,7 +943,7 @@ class WanModel(ModelMixin, ConfigMixin):
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
if callback != None:
callback(-1, False, True)
callback(-1, None, False, True)
if pipeline._interrupt:
if joint_pass:
return None, None

View File

@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
q_size = q.size()
kv_len = k.size(seq_dim)
q_device = q.device
del q,k
# pad v to multiple of 128
# TODO: modify per_channel_fp8 kernel to handle this
kv_len = k.size(seq_dim)
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
if v_pad_len > 0:
if tensor_layout == "HND":

View File

@ -49,40 +49,14 @@ class WanT2V:
config,
checkpoint_dir,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
@ -419,9 +393,9 @@ class WanT2V:
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
@ -438,7 +412,7 @@ class WanT2V:
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, True)
callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
@ -494,7 +468,7 @@ class WanT2V:
del temp_x0
if callback is not None:
callback(i, False)
callback(i, latents[0], False)
x0 = latents

wgp.py (1298 changes)

File diff suppressed because it is too large