Added Preview mode and support for Sky Reels v2 Diffusion Forcing

This commit is contained in:
deepbeepmeep 2025-04-25 22:07:13 +02:00
parent a63aff0377
commit 7c447ea36c
8 changed files with 1155 additions and 832 deletions

View File

@ -10,6 +10,7 @@
## 🔥 Latest News!!
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below)
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
@ -302,18 +303,22 @@ There is also a guide that describes the various combinations of hints (https://g
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
With this mode (which works for the moment only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).

When combined with Vace, this feature can use the same control video to generate the full video that results from concatenating the different windows. For instance the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.

When combined with Sky Reels v2, you can extend an existing video indefinitely.

Sliding Windows are turned on by default and are triggered as soon as you try to generate a video longer than the Window Size. You can go to the Advanced Settings Tab *Sliding Window* to set this Window Size. You can make the video even longer during the generation process by adding one more window each time you click the "Extend the Video Sample, Please !" button.

Although the window duration is set by the *Sliding Window Size* form field, the actual number of frames generated by each iteration will be less, because of the *overlap frames* and *discard last frames*:
- *overlap frames*: the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
- *discard last frames*: quite often (Vace model only) the last frames of a window have worse quality. You can decide here how many ending frames of a new window should be dropped.
Number of Generated Frames = [Number of Windows - 1] * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
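For example, with a hypothetical Window Size of 81 frames, 17 overlap frames, 4 discarded frames and 3 windows, this gives (3 - 1) * (81 - 17 - 4) + 81 = 201 generated frames, i.e. roughly 12.5s of video at 16 frames per second.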
Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
### Command line parameters for Gradio Server
--i2v : launch the image to video generator\

View File

@ -12,7 +12,7 @@ ftfy
dashscope
imageio-ffmpeg
# flash_attn
gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3

View File

@ -1,3 +1,4 @@
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .diffusion_forcing import DTT2V
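The new export makes the Diffusion Forcing pipeline available next to the existing ones. A one-line sketch, assuming the package is imported under the name `wan` (the package path is not shown in this excerpt):

```python
from wan import WanI2V, WanT2V, DTT2V  # DTT2V is the new Diffusion Forcing pipeline
```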

View File

@ -352,7 +352,7 @@ class WanI2V:
# self.model.to(self.device)
if callback != None:
callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@ -426,7 +426,7 @@ class WanI2V:
del timestep
if callback is not None:
callback(i, latent, False)
x0 = [latent.to(self.device, dtype=self.dtype)]
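The `latent` argument threaded through these callbacks is what enables the preview mode announced in this commit: the UI can inspect the intermediate latent at each denoising step. A minimal sketch of a compatible callback, assuming the `(step_no, latent, force_refresh, read_state)` argument order seen in these hunks (the parameter names themselves are guesses):

```python
def preview_callback(step_no, latent, force_refresh, read_state=False):
    # step_no == -1 signals the start of denoising / a UI refresh; latent may be None then.
    if latent is None or step_no < 0:
        return
    # A real preview would decode the latent with the VAE; here we only report its shape.
    print(f"denoising step {step_no}: latent of shape {tuple(latent.shape)}")
```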

View File

@ -10,6 +10,7 @@ import numpy as np
from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel
__all__ = ['WanModel']
@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
return x
def reshape_latent(latent, latent_frames):
    # Diffusion Forcing modulates each latent frame with its own timestep embedding, so the
    # flattened token tensor is reshaped to (latent_frames, tokens_per_frame, dim) before the
    # per-frame modulation is applied, and back again by calling it with latent_frames == 1.
    if latent_frames == latent.shape[0]:
        return latent
    return latent.reshape(latent_frames, -1, latent.shape[-1] )
def identify_k( b: float, d: int, N: int):
@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, grid_sizes, freqs, block_mask = None):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -190,12 +195,44 @@ class WanSelfAttention(nn.Module):
del x
qklist = [q,k]
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
qkv_list = [q,k,v]
del q,k,v
if block_mask == None:
x = pay_attention(
qkv_list,
window_size=self.window_size)
else:
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
x = (
torch.nn.functional.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
)
.transpose(1, 2)
.contiguous()
)
# if not self._flag_ar_attention:
# q = rope_apply(q, grid_sizes, freqs)
# k = rope_apply(k, grid_sizes, freqs)
# x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
# else:
# q = rope_apply(q, grid_sizes, freqs)
# k = rope_apply(k, grid_sizes, freqs)
# q = q.to(torch.bfloat16)
# k = k.to(torch.bfloat16)
# v = v.to(torch.bfloat16)
# with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# x = (
# torch.nn.functional.scaled_dot_product_attention(
# q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
# )
# .transpose(1, 2)
# .contiguous()
# )
# output
x = x.flatten(2)
x = self.o(x)
@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
context,
hints= None,
context_scale=1.0,
cam_emb= None,
block_mask = None
):
r"""
Args:
@ -381,13 +419,14 @@ class WanAttentionBlock(nn.Module):
hint = self.vace(hints, x, **kwargs)
else:
hint = self.vace(hints, None, **kwargs)
latent_frames = e.shape[0]
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
x_mod = self.norm1(x)
x_mod = reshape_latent(x_mod , latent_frames)
x_mod *= 1 + e[1]
x_mod += e[0]
x_mod = reshape_latent(x_mod , 1)
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
@ -397,12 +436,13 @@ class WanAttentionBlock(nn.Module):
xlist = [x_mod]
del x_mod
y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
if cam_emb != None:
y = self.projector(y)
# x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[2])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
del y
y = self.norm3(x)
ylist= [y]
@ -410,8 +450,10 @@ class WanAttentionBlock(nn.Module):
x += self.cross_attn(ylist, context)
y = self.norm2(x)
y = reshape_latent(y , latent_frames)
y *= 1 + e[4]
y += e[3]
y = reshape_latent(y , 1)
ffn = self.ffn[0]
gelu = self.ffn[1]
@ -428,7 +470,9 @@ class WanAttentionBlock(nn.Module):
del mlp_chunk
y = y.view(y_shape)
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
if hint is not None:
if context_scale == 1:
@ -500,10 +544,14 @@ class Head(nn.Module):
""" """
# assert e.dtype == torch.float32 # assert e.dtype == torch.float32
dtype = x.dtype dtype = x.dtype
latent_frames = e.shape[0]
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = self.norm(x).to(dtype) x = self.norm(x).to(dtype)
x = reshape_latent(x , latent_frames)
x *= (1 + e[1]) x *= (1 + e[1])
x += e[0] x += e[0]
x = reshape_latent(x , 1)
x = self.head(x) x = self.head(x)
return x return x
@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
recammaster = False,
inject_sample_info = False,
):
r"""
Initialize the diffusion model backbone.
@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.num_frame_per_block = 1
self.flag_causal_attention = False
self.block_mask = None
self.inject_sample_info = inject_sample_info
# embeddings
self.patch_embedding = nn.Conv3d(
@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
if inject_sample_info:
self.fps_embedding = nn.Embedding(2, dim)
self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
@ -678,12 +735,13 @@ class WanModel(ModelMixin, ConfigMixin):
block.projector.bias = nn.Parameter(torch.zeros(dim))
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
for t in timesteps:
t = torch.stack([t])
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
e_list.append(time_emb)
best_threshold = 0.01
best_diff = 1000
@ -695,16 +753,13 @@ class WanModel(ModelMixin, ConfigMixin):
nb_steps = 0
diff = 1000
for i, t in enumerate(timesteps):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if accumulated_rel_l1_distance < threshold:
skip = True
else:
accumulated_rel_l1_distance = 0
previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
signed_diff = target_nb_steps - nb_steps
@ -739,6 +794,9 @@ class WanModel(ModelMixin, ConfigMixin):
slg_layers=None,
callback = None,
cam_emb: torch.Tensor = None,
fps = None,
causal_block_size = 1,
causal_attention = False,
):
if self.model_type == 'i2v':
@ -752,26 +810,53 @@ class WanModel(ModelMixin, ConfigMixin):
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
# grid_sizes = torch.stack(
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0]
if causal_attention : #causal_block_size > 0:
frame_num = embed_sizes[0]
height = embed_sizes[1]
width = embed_sizes[2]
block_num = frame_num // causal_block_size
range_tensor = torch.arange(block_num).view(-1, 1)
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
del causal_mask
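For Diffusion Forcing, the hunk above builds a block-causal attention mask: tokens of a latent frame may attend to frames in their own causal block and in earlier blocks, but not to later ones. A tiny standalone sketch with hypothetical sizes (4 latent frames, causal blocks of 2, spatial grid collapsed to 1x1) shows the resulting pattern:

```python
import torch

frame_num, causal_block_size = 4, 2
block_num = frame_num // causal_block_size
# Block index of every frame: [0, 0, 1, 1]
range_tensor = torch.arange(block_num).view(-1, 1).repeat(1, causal_block_size).flatten()
# mask[i, j] is True when frame j belongs to the same block as frame i or to an earlier one
mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)
print(mask.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```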
offload.shared_state["embed_sizes"] = embed_sizes offload.shared_state["embed_sizes"] = embed_sizes
offload.shared_state["step_no"] = current_step offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps offload.shared_state["max_steps"] = max_steps
x = [u.flatten(2).transpose(1, 2) for u in x] x = [u.flatten(2).transpose(1, 2) for u in x]
x = x[0] x = x[0]
# time embeddings
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).float()
if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
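The fps conditioning added above selects one of two learned embeddings by integer index and projects it to a bias on the six modulation vectors. A standalone sketch of the shape arithmetic with a hypothetical model width (not the repo's code):

```python
import torch
import torch.nn as nn

dim = 8  # hypothetical model width
fps_embedding = nn.Embedding(2, dim)  # two supported frame-rate indices
fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))

fps_index = torch.tensor([1], dtype=torch.long)            # batch of one fps flag (0 or 1)
fps_emb = fps_embedding(fps_index).float()                 # (1, dim)
e0_bias = fps_projection(fps_emb).unflatten(1, (6, dim))   # (1, 6, dim): one bias per modulation vector
print(e0_bias.shape)  # torch.Size([1, 6, 8])
```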
# context
context = self.text_embedding(
torch.stack([
@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
self.teacache_skipped_steps += 1
@ -858,7 +943,7 @@ class WanModel(ModelMixin, ConfigMixin):
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
if joint_pass:
return None, None

View File

@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
q_size = q.size()
kv_len = k.size(seq_dim)
q_device = q.device
del q,k
# pad v to multiple of 128
# TODO: modify per_channel_fp8 kernel to handle this
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
if v_pad_len > 0:
if tensor_layout == "HND":

View File

@ -49,40 +49,14 @@ class WanT2V:
config,
checkpoint_dir,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
@ -419,9 +393,9 @@ class WanT2V:
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
@ -438,7 +412,7 @@ class WanT2V:
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
@ -494,7 +468,7 @@ class WanT2V:
del temp_x0
if callback is not None:
callback(i, latents[0], False)
x0 = latents

wgp.py: 1790 lines changed

File diff suppressed because it is too large.