mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Added Preview mode and support Sky Reels v2 Diffusion Forcing
This commit is contained in:
parent
a63aff0377
commit
7c447ea36c
21
README.md
21
README.md
@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
|
|
||||||
## 🔥 Latest News!!
|
## 🔥 Latest News!!
|
||||||
|
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Windo siding section below)
|
||||||
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
|
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
|
||||||
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
|
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
|
||||||
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
|
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
|
||||||
@ -302,18 +303,22 @@ There is also a guide that describes the various combination of hints (https://g
|
|||||||
|
|
||||||
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
|
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
|
||||||
|
|
||||||
### VACE Slidig Window
|
### VACE and Sky Reels v2 Diffusion Forcing Slidig Window
|
||||||
With this mode (that works for the moment only with Vace) you can merge mutiple Videos to form a very long video (up to 1 min). What is this very nice a about this feature is that the resulting video can be driven by the same control video. For instance the first 0-4s of the control video will be used to generate the first window then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generate video could contain up to one minute of this person walking.
|
With this mode (that works for the moment only with Vace and Sky Reels v2) you can merge mutiple Videos to form a very long video (up to 1 min).
|
||||||
|
|
||||||
To turn on sliding window, you need to go in the Advanced Settings Tab *Sliding Window* and set the iteration number to a number greater than 1. This number corresponds to the default number of windows. You can still increase the number during the genreation by clicking the "One More Sample, Please !" button.
|
When combined with Vace this feature can use the same control video to generate the full Video that results from concatenining the different windows. For instance the first 0-4s of the control video will be used to generate the first window then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generate video could contain up to one minute of this person walking.
|
||||||
|
|
||||||
Each window duration will be set by the *Number of frames (16 = 1s)* form field. However the actual number of frames generated by each iteration will be less, because the *overlap frames* and *discard last frames*:
|
When combined with Sky Reels V2, you can extend an existing video indefinetely.
|
||||||
- *overlap frames* : the first frames ofa new window are filled with last frames of the previous window in order to ensure continuity between the two windows
|
|
||||||
- *discard last frames* : quite often the last frames of a window have a worse quality. You decide here how many ending frames of a new window should be dropped.
|
|
||||||
|
|
||||||
Number of Generated = [Number of iterations] * ([Number of frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
|
Sliding Windows are turned on by default and are triggered as soon as you try to generate a Video longer than the Window Size. You can go in the Advanced Settings Tab *Sliding Window* to set this Window Size. You can make the Video even longer during the generation process by adding one more Window to generate each time you click "Extend the Video Sample, Please !" button.
|
||||||
|
|
||||||
Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
|
Although the window duration is set by the *Sliding Window Size* form field, the actual number of frames generated by each iteration will be less, because of the *overlap frames* and *discard last frames*:
|
||||||
|
- *overlap frames* : the first frames of a new window are filled with last frames of the previous window in order to ensure continuity between the two windows
|
||||||
|
- *discard last frames* : quite often (Vace model Only) the last frames of a window have a worse quality. You can decide here how many ending frames of a new window should be dropped.
|
||||||
|
s
|
||||||
|
Number of Generated Frames = [Number of Windows - 1] * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
|
||||||
|
|
||||||
|
Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
|
||||||
|
|
||||||
### Command line parameters for Gradio Server
|
### Command line parameters for Gradio Server
|
||||||
--i2v : launch the image to video generator\
|
--i2v : launch the image to video generator\
|
||||||
|
|||||||
@ -12,7 +12,7 @@ ftfy
|
|||||||
dashscope
|
dashscope
|
||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
# flash_attn
|
# flash_attn
|
||||||
gradio>=5.0.0
|
gradio==5.23.0
|
||||||
numpy>=1.23.5,<2
|
numpy>=1.23.5,<2
|
||||||
einops
|
einops
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
from . import configs, distributed, modules
|
from . import configs, distributed, modules
|
||||||
from .image2video import WanI2V
|
from .image2video import WanI2V
|
||||||
from .text2video import WanT2V
|
from .text2video import WanT2V
|
||||||
|
from .diffusion_forcing import DTT2V
|
||||||
@ -352,7 +352,7 @@ class WanI2V:
|
|||||||
|
|
||||||
# self.model.to(self.device)
|
# self.model.to(self.device)
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, True)
|
callback(-1, None, True)
|
||||||
|
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
offload.set_step_no_for_lora(self.model, i)
|
||||||
@ -426,7 +426,7 @@ class WanI2V:
|
|||||||
del timestep
|
del timestep
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, False)
|
callback(i, latent, False)
|
||||||
|
|
||||||
|
|
||||||
x0 = [latent.to(self.device, dtype=self.dtype)]
|
x0 = [latent.to(self.device, dtype=self.dtype)]
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import numpy as np
|
|||||||
from typing import Union,Optional
|
from typing import Union,Optional
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
from .attention import pay_attention
|
from .attention import pay_attention
|
||||||
|
from torch.backends.cuda import sdp_kernel
|
||||||
|
|
||||||
__all__ = ['WanModel']
|
__all__ = ['WanModel']
|
||||||
|
|
||||||
@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_latent(latent, latent_frames):
|
||||||
|
if latent_frames == latent.shape[0]:
|
||||||
|
return latent
|
||||||
|
return latent.reshape(latent_frames, -1, latent.shape[-1] )
|
||||||
|
|
||||||
|
|
||||||
def identify_k( b: float, d: int, N: int):
|
def identify_k( b: float, d: int, N: int):
|
||||||
@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, xlist, grid_sizes, freqs):
|
def forward(self, xlist, grid_sizes, freqs, block_mask = None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||||
@ -190,12 +195,44 @@ class WanSelfAttention(nn.Module):
|
|||||||
del x
|
del x
|
||||||
qklist = [q,k]
|
qklist = [q,k]
|
||||||
del q,k
|
del q,k
|
||||||
|
|
||||||
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
||||||
qkv_list = [q,k,v]
|
qkv_list = [q,k,v]
|
||||||
del q,k,v
|
del q,k,v
|
||||||
x = pay_attention(
|
if block_mask == None:
|
||||||
qkv_list,
|
x = pay_attention(
|
||||||
window_size=self.window_size)
|
qkv_list,
|
||||||
|
window_size=self.window_size)
|
||||||
|
else:
|
||||||
|
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||||
|
x = (
|
||||||
|
torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
||||||
|
)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
# if not self._flag_ar_attention:
|
||||||
|
# q = rope_apply(q, grid_sizes, freqs)
|
||||||
|
# k = rope_apply(k, grid_sizes, freqs)
|
||||||
|
# x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
|
||||||
|
# else:
|
||||||
|
# q = rope_apply(q, grid_sizes, freqs)
|
||||||
|
# k = rope_apply(k, grid_sizes, freqs)
|
||||||
|
# q = q.to(torch.bfloat16)
|
||||||
|
# k = k.to(torch.bfloat16)
|
||||||
|
# v = v.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||||
|
# x = (
|
||||||
|
# torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
# q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
||||||
|
# )
|
||||||
|
# .transpose(1, 2)
|
||||||
|
# .contiguous()
|
||||||
|
# )
|
||||||
|
|
||||||
# output
|
# output
|
||||||
x = x.flatten(2)
|
x = x.flatten(2)
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
|
|||||||
context,
|
context,
|
||||||
hints= None,
|
hints= None,
|
||||||
context_scale=1.0,
|
context_scale=1.0,
|
||||||
cam_emb= None
|
cam_emb= None,
|
||||||
|
block_mask = None
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -381,13 +419,14 @@ class WanAttentionBlock(nn.Module):
|
|||||||
hint = self.vace(hints, x, **kwargs)
|
hint = self.vace(hints, x, **kwargs)
|
||||||
else:
|
else:
|
||||||
hint = self.vace(hints, None, **kwargs)
|
hint = self.vace(hints, None, **kwargs)
|
||||||
|
latent_frames = e.shape[0]
|
||||||
e = (self.modulation + e).chunk(6, dim=1)
|
e = (self.modulation + e).chunk(6, dim=1)
|
||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
x_mod = self.norm1(x)
|
x_mod = self.norm1(x)
|
||||||
|
x_mod = reshape_latent(x_mod , latent_frames)
|
||||||
x_mod *= 1 + e[1]
|
x_mod *= 1 + e[1]
|
||||||
x_mod += e[0]
|
x_mod += e[0]
|
||||||
|
x_mod = reshape_latent(x_mod , 1)
|
||||||
if cam_emb != None:
|
if cam_emb != None:
|
||||||
cam_emb = self.cam_encoder(cam_emb)
|
cam_emb = self.cam_encoder(cam_emb)
|
||||||
cam_emb = cam_emb.repeat(1, 2, 1)
|
cam_emb = cam_emb.repeat(1, 2, 1)
|
||||||
@ -397,12 +436,13 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
xlist = [x_mod]
|
xlist = [x_mod]
|
||||||
del x_mod
|
del x_mod
|
||||||
y = self.self_attn( xlist, grid_sizes, freqs)
|
y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
|
||||||
if cam_emb != None:
|
if cam_emb != None:
|
||||||
y = self.projector(y)
|
y = self.projector(y)
|
||||||
# x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
|
|
||||||
|
|
||||||
|
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
||||||
x.addcmul_(y, e[2])
|
x.addcmul_(y, e[2])
|
||||||
|
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
|
||||||
del y
|
del y
|
||||||
y = self.norm3(x)
|
y = self.norm3(x)
|
||||||
ylist= [y]
|
ylist= [y]
|
||||||
@ -410,8 +450,10 @@ class WanAttentionBlock(nn.Module):
|
|||||||
x += self.cross_attn(ylist, context)
|
x += self.cross_attn(ylist, context)
|
||||||
y = self.norm2(x)
|
y = self.norm2(x)
|
||||||
|
|
||||||
|
y = reshape_latent(y , latent_frames)
|
||||||
y *= 1 + e[4]
|
y *= 1 + e[4]
|
||||||
y += e[3]
|
y += e[3]
|
||||||
|
y = reshape_latent(y , 1)
|
||||||
|
|
||||||
ffn = self.ffn[0]
|
ffn = self.ffn[0]
|
||||||
gelu = self.ffn[1]
|
gelu = self.ffn[1]
|
||||||
@ -428,7 +470,9 @@ class WanAttentionBlock(nn.Module):
|
|||||||
del mlp_chunk
|
del mlp_chunk
|
||||||
y = y.view(y_shape)
|
y = y.view(y_shape)
|
||||||
|
|
||||||
|
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
||||||
x.addcmul_(y, e[5])
|
x.addcmul_(y, e[5])
|
||||||
|
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
|
||||||
|
|
||||||
if hint is not None:
|
if hint is not None:
|
||||||
if context_scale == 1:
|
if context_scale == 1:
|
||||||
@ -500,10 +544,14 @@ class Head(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# assert e.dtype == torch.float32
|
# assert e.dtype == torch.float32
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
|
|
||||||
|
latent_frames = e.shape[0]
|
||||||
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
x = self.norm(x).to(dtype)
|
x = self.norm(x).to(dtype)
|
||||||
|
x = reshape_latent(x , latent_frames)
|
||||||
x *= (1 + e[1])
|
x *= (1 + e[1])
|
||||||
x += e[0]
|
x += e[0]
|
||||||
|
x = reshape_latent(x , 1)
|
||||||
x = self.head(x)
|
x = self.head(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
recammaster = False
|
recammaster = False,
|
||||||
|
inject_sample_info = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initialize the diffusion model backbone.
|
Initialize the diffusion model backbone.
|
||||||
@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.qk_norm = qk_norm
|
self.qk_norm = qk_norm
|
||||||
self.cross_attn_norm = cross_attn_norm
|
self.cross_attn_norm = cross_attn_norm
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
self.num_frame_per_block = 1
|
||||||
|
self.flag_causal_attention = False
|
||||||
|
self.block_mask = None
|
||||||
|
self.inject_sample_info = inject_sample_info
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
self.patch_embedding = nn.Conv3d(
|
self.patch_embedding = nn.Conv3d(
|
||||||
@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
||||||
nn.Linear(dim, dim))
|
nn.Linear(dim, dim))
|
||||||
|
|
||||||
|
if inject_sample_info:
|
||||||
|
self.fps_embedding = nn.Embedding(2, dim)
|
||||||
|
self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||||
|
|
||||||
self.time_embedding = nn.Sequential(
|
self.time_embedding = nn.Sequential(
|
||||||
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||||
@ -678,12 +735,13 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
|
|
||||||
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
||||||
rescale_func = np.poly1d(self.coefficients)
|
rescale_func = np.poly1d(self.coefficients)
|
||||||
e_list = []
|
e_list = []
|
||||||
for t in timesteps:
|
for t in timesteps:
|
||||||
t = torch.stack([t])
|
t = torch.stack([t])
|
||||||
e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
|
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
|
||||||
|
e_list.append(time_emb)
|
||||||
|
|
||||||
best_threshold = 0.01
|
best_threshold = 0.01
|
||||||
best_diff = 1000
|
best_diff = 1000
|
||||||
@ -695,16 +753,13 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
nb_steps = 0
|
nb_steps = 0
|
||||||
diff = 1000
|
diff = 1000
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
skip = False
|
skip = False
|
||||||
if not (i<=start_step or i== len(timesteps)):
|
if not (i<=start_step or i== len(timesteps)):
|
||||||
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
|
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
|
||||||
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
|
|
||||||
|
|
||||||
if accumulated_rel_l1_distance < threshold:
|
if accumulated_rel_l1_distance < threshold:
|
||||||
skip = True
|
skip = True
|
||||||
else:
|
else:
|
||||||
accumulated_rel_l1_distance = 0
|
accumulated_rel_l1_distance = 0
|
||||||
previous_modulated_input = e_list[i]
|
|
||||||
if not skip:
|
if not skip:
|
||||||
nb_steps += 1
|
nb_steps += 1
|
||||||
signed_diff = target_nb_steps - nb_steps
|
signed_diff = target_nb_steps - nb_steps
|
||||||
@ -739,6 +794,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
slg_layers=None,
|
slg_layers=None,
|
||||||
callback = None,
|
callback = None,
|
||||||
cam_emb: torch.Tensor = None,
|
cam_emb: torch.Tensor = None,
|
||||||
|
fps = None,
|
||||||
|
causal_block_size = 1,
|
||||||
|
causal_attention = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if self.model_type == 'i2v':
|
if self.model_type == 'i2v':
|
||||||
@ -752,26 +810,53 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||||
# grid_sizes = torch.stack(
|
# grid_sizes = torch.stack(
|
||||||
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||||
|
|
||||||
grid_sizes = [ list(u.shape[2:]) for u in x]
|
grid_sizes = [ list(u.shape[2:]) for u in x]
|
||||||
embed_sizes = grid_sizes[0]
|
embed_sizes = grid_sizes[0]
|
||||||
|
if causal_attention : #causal_block_size > 0:
|
||||||
|
frame_num = embed_sizes[0]
|
||||||
|
height = embed_sizes[1]
|
||||||
|
width = embed_sizes[2]
|
||||||
|
block_num = frame_num // causal_block_size
|
||||||
|
range_tensor = torch.arange(block_num).view(-1, 1)
|
||||||
|
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
|
||||||
|
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
|
||||||
|
causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
|
||||||
|
causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
|
||||||
|
causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
|
||||||
|
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
||||||
|
del causal_mask
|
||||||
|
|
||||||
offload.shared_state["embed_sizes"] = embed_sizes
|
offload.shared_state["embed_sizes"] = embed_sizes
|
||||||
offload.shared_state["step_no"] = current_step
|
offload.shared_state["step_no"] = current_step
|
||||||
offload.shared_state["max_steps"] = max_steps
|
offload.shared_state["max_steps"] = max_steps
|
||||||
|
|
||||||
|
|
||||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||||
x = x[0]
|
x = x[0]
|
||||||
|
|
||||||
# time embeddings
|
if t.dim() == 2:
|
||||||
|
b, f = t.shape
|
||||||
|
_flag_df = True
|
||||||
|
else:
|
||||||
|
_flag_df = False
|
||||||
|
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
sinusoidal_embedding_1d(self.freq_dim, t))
|
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
||||||
|
) # b, dim
|
||||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||||
|
|
||||||
|
if self.inject_sample_info:
|
||||||
|
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
fps_emb = self.fps_embedding(fps).float()
|
||||||
|
if _flag_df:
|
||||||
|
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
||||||
|
else:
|
||||||
|
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
|
||||||
|
|
||||||
# context
|
# context
|
||||||
context = self.text_embedding(
|
context = self.text_embedding(
|
||||||
torch.stack([
|
torch.stack([
|
||||||
@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.accumulated_rel_l1_distance = 0
|
self.accumulated_rel_l1_distance = 0
|
||||||
else:
|
else:
|
||||||
rescale_func = np.poly1d(self.coefficients)
|
rescale_func = np.poly1d(self.coefficients)
|
||||||
self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
|
||||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||||
should_calc = False
|
should_calc = False
|
||||||
self.teacache_skipped_steps += 1
|
self.teacache_skipped_steps += 1
|
||||||
@ -858,7 +943,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
for block_idx, block in enumerate(self.blocks):
|
for block_idx, block in enumerate(self.blocks):
|
||||||
offload.shared_state["layer"] = block_idx
|
offload.shared_state["layer"] = block_idx
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, False, True)
|
callback(-1, None, False, True)
|
||||||
if pipeline._interrupt:
|
if pipeline._interrupt:
|
||||||
if joint_pass:
|
if joint_pass:
|
||||||
return None, None
|
return None, None
|
||||||
|
|||||||
@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
|
|||||||
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
|
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
|
||||||
|
|
||||||
q_size = q.size()
|
q_size = q.size()
|
||||||
|
kv_len = k.size(seq_dim)
|
||||||
q_device = q.device
|
q_device = q.device
|
||||||
del q,k
|
del q,k
|
||||||
|
|
||||||
|
|
||||||
# pad v to multiple of 128
|
# pad v to multiple of 128
|
||||||
# TODO: modify per_channel_fp8 kernel to handle this
|
# TODO: modify per_channel_fp8 kernel to handle this
|
||||||
kv_len = k.size(seq_dim)
|
|
||||||
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
|
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
|
||||||
if v_pad_len > 0:
|
if v_pad_len > 0:
|
||||||
if tensor_layout == "HND":
|
if tensor_layout == "HND":
|
||||||
|
|||||||
@ -49,40 +49,14 @@ class WanT2V:
|
|||||||
config,
|
config,
|
||||||
checkpoint_dir,
|
checkpoint_dir,
|
||||||
rank=0,
|
rank=0,
|
||||||
t5_fsdp=False,
|
|
||||||
dit_fsdp=False,
|
|
||||||
use_usp=False,
|
|
||||||
t5_cpu=False,
|
|
||||||
model_filename = None,
|
model_filename = None,
|
||||||
text_encoder_filename = None,
|
text_encoder_filename = None,
|
||||||
quantizeTransformer = False,
|
quantizeTransformer = False,
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Initializes the Wan text-to-video generation model components.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (EasyDict):
|
|
||||||
Object containing model parameters initialized from config.py
|
|
||||||
checkpoint_dir (`str`):
|
|
||||||
Path to directory containing model checkpoints
|
|
||||||
device_id (`int`, *optional*, defaults to 0):
|
|
||||||
Id of target GPU device
|
|
||||||
rank (`int`, *optional*, defaults to 0):
|
|
||||||
Process rank for distributed training
|
|
||||||
t5_fsdp (`bool`, *optional*, defaults to False):
|
|
||||||
Enable FSDP sharding for T5 model
|
|
||||||
dit_fsdp (`bool`, *optional*, defaults to False):
|
|
||||||
Enable FSDP sharding for DiT model
|
|
||||||
use_usp (`bool`, *optional*, defaults to False):
|
|
||||||
Enable distribution strategy of USP.
|
|
||||||
t5_cpu (`bool`, *optional*, defaults to False):
|
|
||||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
|
||||||
"""
|
|
||||||
self.device = torch.device(f"cuda")
|
self.device = torch.device(f"cuda")
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.t5_cpu = t5_cpu
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.num_train_timesteps = config.num_train_timesteps
|
self.num_train_timesteps = config.num_train_timesteps
|
||||||
self.param_dtype = config.param_dtype
|
self.param_dtype = config.param_dtype
|
||||||
@ -419,9 +393,9 @@ class WanT2V:
|
|||||||
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
||||||
else:
|
else:
|
||||||
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
|
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
|
||||||
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
|
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
|
||||||
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
|
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
|
||||||
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
|
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
|
||||||
|
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
recam_dict = {'cam_emb': cam_emb}
|
recam_dict = {'cam_emb': cam_emb}
|
||||||
@ -438,7 +412,7 @@ class WanT2V:
|
|||||||
if self.model.enable_teacache:
|
if self.model.enable_teacache:
|
||||||
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, True)
|
callback(-1, None, True)
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
|
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
|
||||||
@ -494,7 +468,7 @@ class WanT2V:
|
|||||||
del temp_x0
|
del temp_x0
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, False)
|
callback(i, latents[0], False)
|
||||||
|
|
||||||
x0 = latents
|
x0 = latents
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user