Fixed pytorch compilation

This commit is contained in:
DeepBeepMeep 2025-03-08 16:37:21 +01:00
parent 28f19586a5
commit f9ce97a1ba
6 changed files with 502 additions and 31 deletions

View File

@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!! ## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process. * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
If you upgrade you will need to do a 'pip install -r requirements.txt' again. If you upgrade you will need to do a 'pip install -r requirements.txt' again.
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end

View File

@ -853,7 +853,7 @@ def generate_video(
if use_image2video: if use_image2video:
samples = wan_model.generate( samples = wan_model.generate(
prompt, prompt,
image_to_continue[video_no-1], image_to_continue[ (video_no-1) % len(image_to_continue)],
frame_num=(video_length // 4)* 4 + 1, frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution], max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift, shift=flow_shift,

View File

@ -24,7 +24,7 @@ from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps) get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
class WanI2V: class WanI2V:
@ -290,7 +290,7 @@ class WanI2V:
# sample videos # sample videos
latent = noise latent = noise
freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None ) freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx )
arg_c = { arg_c = {
'context': [context[0]], 'context': [context[0]],
@ -318,6 +318,7 @@ class WanI2V:
callback(-1, None) callback(-1, None)
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(i)
latent_model_input = [latent.to(self.device)] latent_model_input = [latent.to(self.device)]
timestep = [t] timestep = [t]

View File

@ -67,17 +67,15 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow) freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
freqs = torch.polar(torch.ones_like(freqs), freqs) if True:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return (freqs_cos, freqs_sin)
else:
freqs = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs return freqs
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def rope_apply_(x, grid_sizes, freqs): def rope_apply_(x, grid_sizes, freqs):
@ -209,6 +207,7 @@ class WanLayerNorm(nn.LayerNorm):
return x return x
# return super().forward(x).type_as(x) # return super().forward(x).type_as(x)
from wan.modules.posemb_layers import apply_rotary_emb
class WanSelfAttention(nn.Module): class WanSelfAttention(nn.Module):
@ -257,8 +256,11 @@ class WanSelfAttention(nn.Module):
k = k.view(b, s, n, d) k = k.view(b, s, n, d)
v = self.v(x).view(b, s, n, d) v = self.v(x).view(b, s, n, d)
del x del x
rope_apply_(q, grid_sizes, freqs) # rope_apply_(q, grid_sizes, freqs)
rope_apply_(k, grid_sizes, freqs) # rope_apply_(k, grid_sizes, freqs)
qklist = [q,k]
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
qkv_list = [q,k,v] qkv_list = [q,k,v]
del q,k,v del q,k,v
x = pay_attention( x = pay_attention(
@ -652,20 +654,18 @@ class WanModel(ModelMixin, ConfigMixin):
# ],dim=1) # ],dim=1)
def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None): def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
dim = self.dim dim = self.dim
num_heads = self.num_heads num_heads = self.num_heads
d = dim // num_heads d = dim // num_heads
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
freqs = torch.cat([ c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44 c2, s2 = rope_params(1024, 2 * (d // 6)) #42
rope_params(1024, 2 * (d // 6)), #42 c3, s3 = rope_params(1024, 2 * (d // 6)) #42
rope_params(1024, 2 * (d // 6)) #42
],dim=1)
return freqs return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
def forward( def forward(
@ -706,7 +706,7 @@ class WanModel(ModelMixin, ConfigMixin):
assert clip_fea is not None and y is not None assert clip_fea is not None and y is not None
# params # params
device = self.patch_embedding.weight.device device = self.patch_embedding.weight.device
if freqs.device != device: if torch.is_tensor(freqs) and freqs.device != device:
freqs = freqs.to(device) freqs = freqs.to(device)
if y is not None: if y is not None:

View File

@ -0,0 +1,474 @@
import torch
from typing import Union, Tuple, List, Optional
import numpy as np
###### Thanks to the RifleX project (https://github.com/thu-ml/RIFLEx/) for this alternative pos embed for long videos
#
def get_1d_rotary_pos_embed_riflex(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
k: Optional[int] = None,
L_test: Optional[int] = None,
):
"""
RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
L_test (`int`, *optional*, defaults to None): the number of frames for inference
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
) # [D/2]
# === Riflex modification start ===
# Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
# Empirical observations show that a few videos may exhibit repetition in the tail frames.
# To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
if k is not None:
freqs[k-1] = 0.9 * 2 * torch.pi / L_test
# === Riflex modification end ===
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def identify_k( b: float, d: int, N: int):
"""
This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
Args:
b (`float`): The base frequency for RoPE.
d (`int`): Dimension of the frequency tensor
N (`int`): the first observed repetition frame in latent space
Returns:
k (`int`): the index of intrinsic frequency component
N_k (`int`): the period of intrinsic frequency component in latent space
Example:
In HunyuanVideo, b=256 and d=16, the repetition occurs approximately 8s (N=48 in latent space).
k, N_k = identify_k(b=256, d=16, N=48)
In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
"""
# Compute the period of each frequency in RoPE according to Eq.(4)
periods = []
for j in range(1, d // 2 + 1):
theta_j = 1.0 / (b ** (2 * (j - 1) / d))
N_j = round(2 * torch.pi / theta_j)
periods.append(N_j)
# Identify the intrinsic frequency whose period is closed to Nsee Eq.(7)
diffs = [abs(N_j - N) for N_j in periods]
k = diffs.index(min(diffs)) + 1
N_k = periods[k-1]
return k, N_k
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb( qklist,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xq, xk = qklist
qklist.clear()
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_dtype = xq.dtype
xq_out = xq.to(torch.float)
xq = None
xq_rot = rotate_half(xq_out)
xq_out *= cos
xq_rot *= sin
xq_out += xq_rot
del xq_rot
xq_out = xq_out.to(xq_dtype)
xk_out = xk.to(torch.float)
xk = None
xk_rot = rotate_half(xk_out)
xk_out *= cos
xk_rot *= sin
xk_out += xk_rot
del xk_rot
xk_out = xk_out.to(xq_dtype)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
k = 6,
L_test = 66,
enable_riflex = True
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
# emb = get_1d_rotary_pos_embed(
# rope_dim_list[i],
# grid[i].reshape(-1),
# theta,
# use_real=use_real,
# theta_rescale_factor=theta_rescale_factor[i],
# interpolation_factor=interpolation_factor[i],
# ) # 2 x [WHD, rope_dim_list[i]]
# === RIFLEx modification start ===
# apply RIFLEx for time dimension
if i == 0 and enable_riflex:
emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
# === RIFLEx modification end ===
else:
emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
target_ndim = 3
ndim = 5 - 2
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
patch_size = [1, 2, 2]
if isinstance(patch_size, int):
assert all(s % patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // patch_size for s in latents_size]
elif isinstance(patch_size, list):
assert all(
s % patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // patch_size[idx] for idx, s in enumerate(latents_size)
]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = 128
rope_dim_list = [44, 42, 42]
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=10000,
use_real=True,
theta_rescale_factor=1,
L_test = (video_length - 1) // 4 + 1,
enable_riflex = enable_RIFLEx
)
return (freqs_cos, freqs_sin)

View File

@ -8,7 +8,7 @@ import sys
import types import types
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from mmgp import offload
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.distributed as dist import torch.distributed as dist
@ -21,6 +21,7 @@ from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps) get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
class WanT2V: class WanT2V:
@ -236,13 +237,7 @@ class WanT2V:
# sample videos # sample videos
latents = noise latents = noise
# from .modules.model import identify_k freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
# for nf in range(20, 50):
# k, N_k = identify_k(10000, 44, 26)
# print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@ -252,7 +247,7 @@ class WanT2V:
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents latent_model_input = latents
timestep = [t] timestep = [t]
offload.set_step_no_for_lora(i)
timestep = torch.stack(timestep) timestep = torch.stack(timestep)
# self.model.to(self.device) # self.model.to(self.device)