Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 06:29:14 +00:00)

Commit f9ce97a1ba ("Fixed pytorch compilation"), parent 28f19586a5
README.md
@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 
 ## 🔥 Latest News!!
 
+* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; generation is now really 20% faster when it is enabled
 * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line switch *--preload x* to preload x MB of the main diffusion model in VRAM, if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
 If you upgrade you will need to run 'pip install -r requirements.txt' again.
 * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
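(For reference, the two v1.3 switches compose: a launch line ending in `--multiple-images --preload 1000` would cycle the supplied reference images across the queued prompts while pinning 1000 MB of the diffusion model in VRAM. The server script's name is not shown in this capture, so the full command is omitted, and the 1000 MB figure is only an illustration.)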
(Gradio app script; the file is not named in this capture, but the hunk is in its generate_video function)
@@ -853,7 +853,7 @@ def generate_video(
     if use_image2video:
         samples = wan_model.generate(
             prompt,
-            image_to_continue[video_no-1],
+            image_to_continue[ (video_no-1) % len(image_to_continue)],
             frame_num=(video_length // 4)* 4 + 1,
             max_area=MAX_AREA_CONFIGS[resolution],
             shift=flow_shift,
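The new modulo index is what lets one run reuse a short image list across many prompt/video combinations instead of raising an IndexError once video_no exceeds the number of images. A minimal sketch of the pattern, with illustrative data (the names below are not from the repo):

    # Cycle 2 reference images across 4 queued generations:
    # videos 1..4 pick images at indices 0, 1, 0, 1.
    images = ["ref_a.png", "ref_b.png"]
    prompts = ["sunrise", "storm", "snow", "night"]
    for video_no, prompt in enumerate(prompts, start=1):
        image = images[(video_no - 1) % len(images)]  # wraps around the list
        print(video_no, prompt, image)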
wan/image2video.py
@@ -24,7 +24,7 @@ from .modules.vae import WanVAE
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from wan.modules.posemb_layers import get_rotary_pos_embed
 
 
 class WanI2V:
@@ -290,7 +290,7 @@ class WanI2V:
         # sample videos
         latent = noise
 
-        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
+        freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx )
 
         arg_c = {
             'context': [context[0]],
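Both the old and new call ultimately size the RoPE tables by the latent token grid, where frame_num pixel frames collapse to (frame_num - 1) // 4 + 1 latent frames. A quick sanity check of the mapping used throughout this diff (the concrete values are illustrative, not read from the repo's configs):

    frame_num, h, w = 81, 480, 832
    latents = [(frame_num - 1) // 4 + 1, h // 8, w // 8]  # temporal x4, spatial x8 VAE
    patch = [1, 2, 2]                                     # DiT patchify
    rope_sizes = [s // p for s, p in zip(latents, patch)]
    print(latents, rope_sizes)  # [21, 60, 104] [21, 30, 52]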
@@ -318,6 +318,7 @@ class WanI2V:
         callback(-1, None)
 
         for i, t in enumerate(tqdm(timesteps)):
+            offload.set_step_no_for_lora(i)
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
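The `offload.set_step_no_for_lora(i)` calls (here and in the matching WanT2V loop below) come from the mmgp package; presumably they report the current denoising step to the offload layer so step-dependent LoRA scheduling can key off it. That reading is inferred from the name only, as this diff does not show the mmgp side.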
wan/modules/model.py
@@ -67,17 +67,15 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
         inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
 
     freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
+    if True:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+        return (freqs_cos, freqs_sin)
+    else:
+        freqs = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
     return freqs
-
-
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float32).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
 
 
 def rope_apply_(x, grid_sizes, freqs):
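The switch from torch.polar to an explicit (cos, sin) pair is what makes these tables torch.compile-friendly (complex64 ops are a common source of graph breaks). The two encodings carry the same information; a small self-contained check of that equivalence, independent of the repo's code:

    import torch

    S, D = 4, 8
    inv = 1.0 / torch.pow(10000, torch.arange(0, D, 2).float() / D)
    angles = torch.outer(torch.arange(S).float(), inv)       # [S, D/2]

    cplx = torch.polar(torch.ones_like(angles), angles)      # complex64, [S, D/2]
    cos = angles.cos().repeat_interleave(2, dim=1)           # real, [S, D]
    sin = angles.sin().repeat_interleave(2, dim=1)

    # The real pair stores exactly the complex table's cos/sin parts, duplicated
    # so each of the D real channels has its rotation coefficient in place.
    assert torch.allclose(cplx.real.repeat_interleave(2, dim=1), cos)
    assert torch.allclose(cplx.imag.repeat_interleave(2, dim=1), sin)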
@@ -209,6 +207,7 @@ class WanLayerNorm(nn.LayerNorm):
         return x
     # return super().forward(x).type_as(x)
 
+from wan.modules.posemb_layers import apply_rotary_emb
 
 class WanSelfAttention(nn.Module):
@@ -257,8 +256,11 @@ class WanSelfAttention(nn.Module):
         k = k.view(b, s, n, d)
         v = self.v(x).view(b, s, n, d)
         del x
-        rope_apply_(q, grid_sizes, freqs)
-        rope_apply_(k, grid_sizes, freqs)
+        # rope_apply_(q, grid_sizes, freqs)
+        # rope_apply_(k, grid_sizes, freqs)
+        qklist = [q,k]
+        del q,k
+        q,k = apply_rotary_emb(qklist, freqs, head_first=False)
         qkv_list = [q,k,v]
         del q,k,v
         x = pay_attention(
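Passing q and k inside a list and deleting the local names before the call is a deliberate VRAM trick (the same idiom is used for qkv_list): the callee runs qklist.clear() after unpacking, so once it drops its own locals (the real code sets xq = None right after the float upcast), no reference to the pre-rotation tensors survives and their memory can be freed before the rotated copies fully materialize. A stripped-down illustration of the idiom:

    import torch

    def consume(tensors):
        a, b = tensors
        tensors.clear()         # empty the caller's list slots...
        return a * 2, b * 2     # ...so `a`/`b` hold the last refs, which die on return

    q, k = torch.ones(4), torch.ones(4)
    qklist = [q, k]
    del q, k                    # caller keeps no direct reference
    q, k = consume(qklist)      # originals become collectible during/after the call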
@@ -652,20 +654,18 @@ class WanModel(ModelMixin, ConfigMixin):
         # ],dim=1)
 
 
-    def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
+    def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
         dim = self.dim
         num_heads = self.num_heads
         d = dim // num_heads
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 
-        freqs = torch.cat([
-            rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44
-            rope_params(1024, 2 * (d // 6)), #42
-            rope_params(1024, 2 * (d // 6)) #42
-        ],dim=1)
+        c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
+        c2, s2 = rope_params(1024, 2 * (d // 6)) #42
+        c3, s3 = rope_params(1024, 2 * (d // 6)) #42
 
-        return freqs
+        return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
 
 
     def forward(
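The #44 / #42 comments are the per-axis channel split of the 128-dim attention head: the temporal axis gets d - 4 * (d // 6) channels and each spatial axis gets 2 * (d // 6). A one-liner to confirm the split (d = 128 matches the head_dim that the new posemb_layers.py below hard-codes):

    d = 128
    t, h, w = d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)
    print(t, h, w, t + h + w)  # 44 42 42 128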
@@ -706,7 +706,7 @@ class WanModel(ModelMixin, ConfigMixin):
         assert clip_fea is not None and y is not None
         # params
         device = self.patch_embedding.weight.device
-        if freqs.device != device:
+        if torch.is_tensor(freqs) and freqs.device != device:
             freqs = freqs.to(device)
 
         if y is not None:
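The torch.is_tensor guard is needed because freqs can now arrive as a (cos, sin) tuple rather than a single complex tensor; a tuple has no .device attribute, and the tuple variants are already placed on the right device by their producers (get_rope_freqs's .to(device) above, and apply_rotary_emb's cos.to(xq.device) in the new file below).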
wan/modules/posemb_layers.py (new file, 474 lines)
@@ -0,0 +1,474 @@
import torch
from typing import Union, Tuple, List, Optional
import numpy as np


###### Thanks to the RifleX project (https://github.com/thu-ml/RIFLEx/) for this alternative pos embed for long videos
#
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
        L_test (`int`, *optional*, defaults to None): the number of frames for inference
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore # [S]

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
    )  # [D/2]

    # === Riflex modification start ===
    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
    if k is not None:
        freqs[k-1] = 0.9 * 2 * torch.pi / L_test
    # === Riflex modification end ===
    freqs = torch.outer(pos, freqs)  # type: ignore # [S, D/2]

    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
        return freqs_cis


def identify_k(b: float, d: int, N: int):
    """
    This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.

    Args:
        b (`float`): The base frequency for RoPE.
        d (`int`): Dimension of the frequency tensor
        N (`int`): The first observed repetition frame in latent space
    Returns:
        k (`int`): The index of the intrinsic frequency component
        N_k (`int`): The period of the intrinsic frequency component in latent space
    Example:
        In HunyuanVideo, b=256 and d=16, the repetition occurs approximately at 8s (N=48 in latent space).
        k, N_k = identify_k(b=256, d=16, N=48)
        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
    """

    # Compute the period of each frequency in RoPE according to Eq.(4)
    periods = []
    for j in range(1, d // 2 + 1):
        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
        N_j = round(2 * torch.pi / theta_j)
        periods.append(N_j)

    # Identify the intrinsic frequency whose period is closest to N (see Eq.(7))
    diffs = [abs(N_j - N) for N_j in periods]
    k = diffs.index(min(diffs)) + 1
    N_k = periods[k-1]
    return k, N_k


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start, args[0] is stop,
            step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If an n-tuple is provided, the meshgrid will be stacked following the dim order
            in the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (np.ndarray): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)    # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)   # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)    # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80


def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = (
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    )  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
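rotate_half treats the last dimension as interleaved (real, imag) pairs and maps each pair (a, b) to (-b, a); combined with the repeat_interleave'd cos/sin tables above, this realizes complex multiplication on real tensors. A tiny worked check (note that rotate_half as written assumes a 4-D [B, S, H, D] input because of flatten(3)):

    import torch
    x = torch.tensor([1., 2., 3., 4.]).view(1, 1, 1, 4)  # pairs (1,2) and (3,4)
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    print(rotated)  # tensor([[[[-2., 1., -4., 3.]]]])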
def apply_rotary_emb(
    qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        qklist: two-element list [xq, xk] holding the query and key tensors, each [B, S, H, D];
            the list is emptied inside this function so the caller's references are released early.
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq, xk = qklist
    qklist.clear()
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_dtype = xq.dtype
        xq_out = xq.to(torch.float)
        xq = None
        xq_rot = rotate_half(xq_out)
        xq_out *= cos
        xq_rot *= sin
        xq_out += xq_rot
        del xq_rot
        xq_out = xq_out.to(xq_dtype)

        xk_out = xk.to(torch.float)
        xk = None
        xk_rot = rotate_half(xk_out)
        xk_out *= cos
        xk_rot *= sin
        xk_out += xk_rot
        del xk_rot
        xk_out = xk_out.to(xq_dtype)
    else:
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
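Both branches compute the same rotation; the tuple branch just does it with real tensors and in-place ops, dodging complex64 (useful for torch.compile and TensorRT-style backends, as the docstrings note) and trimming peak memory. A quick self-contained cross-check of the two formulations on random data:

    import torch

    B, S, H, D = 1, 5, 2, 8
    xq = torch.randn(B, S, H, D)
    ang = torch.outer(torch.arange(S).float(),
                      1.0 / 10000 ** (torch.arange(0, D, 2).float() / D))  # [S, D/2]

    # complex path
    cis = torch.polar(torch.ones_like(ang), ang).view(1, S, 1, D // 2)
    xq_c = torch.view_as_complex(xq.reshape(B, S, H, D // 2, 2))
    out_complex = torch.view_as_real(xq_c * cis).flatten(3)

    # real (cos, sin) path
    cos = ang.cos().repeat_interleave(2, dim=1).view(1, S, 1, D)
    sin = ang.sin().repeat_interleave(2, dim=1).view(1, S, 1, D)
    xr, xi = xq.reshape(B, S, H, D // 2, 2).unbind(-1)
    rot = torch.stack([-xi, xr], dim=-1).flatten(3)  # rotate_half
    out_real = xq * cos + rot * sin

    assert torch.allclose(out_complex, out_real, atol=1e-5)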
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    k=6,
    L_test=66,
    enable_riflex=True,
):
    """
    This is an n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start,
            args[0] is stop, step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """

    grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list)
    )  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        # emb = get_1d_rotary_pos_embed(
        #     rope_dim_list[i],
        #     grid[i].reshape(-1),
        #     theta,
        #     use_real=use_real,
        #     theta_rescale_factor=theta_rescale_factor[i],
        #     interpolation_factor=interpolation_factor[i],
        # )  # 2 x [WHD, rope_dim_list[i]]

        # === RIFLEx modification start ===
        # apply RIFLEx for the time dimension
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        # === RIFLEx modification end ===
        else:
            emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i], interpolation_factor=interpolation_factor[i],)
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(
            torch.ones_like(freqs), freqs
        )  # complex64 # [S, D/2]
        return freqs_cis


def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx=False):
    target_ndim = 3
    ndim = 5 - 2

    latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
    patch_size = [1, 2, 2]
    if isinstance(patch_size, int):
        assert all(s % patch_size == 0 for s in latents_size), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // patch_size for s in latents_size]
    elif isinstance(patch_size, list):
        assert all(
            s % patch_size[idx] == 0
            for idx, s in enumerate(latents_size)
        ), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [
            s // patch_size[idx] for idx, s in enumerate(latents_size)
        ]

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
    head_dim = 128
    rope_dim_list = [44, 42, 42]
    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert (
        sum(rope_dim_list) == head_dim
    ), "sum(rope_dim_list) should equal to head_dim of attention layer"
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list,
        rope_sizes,
        theta=10000,
        use_real=True,
        theta_rescale_factor=1,
        L_test=(video_length - 1) // 4 + 1,
        enable_riflex=enable_RIFLEx,
    )
    return (freqs_cos, freqs_sin)
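A quick shape check of the new entry point, using 480p-style inputs purely as an illustration: 81 frames at 480x832 give a 21x30x52 token grid (32760 tokens), so each returned table is [32760, 128], with the 44 RIFLEx-adjusted temporal channels first:

    cos, sin = get_rotary_pos_embed(81, 480, 832, enable_RIFLEx=True)
    print(cos.shape, sin.shape)  # torch.Size([32760, 128]) torch.Size([32760, 128])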
wan/text2video.py
@@ -8,7 +8,7 @@ import sys
 import types
 from contextlib import contextmanager
 from functools import partial
+from mmgp import offload
 import torch
 import torch.cuda.amp as amp
 import torch.distributed as dist
@@ -21,6 +21,7 @@ from .modules.vae import WanVAE
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from wan.modules.posemb_layers import get_rotary_pos_embed
 
 
 class WanT2V:
@@ -236,13 +237,7 @@ class WanT2V:
         # sample videos
         latents = noise
 
-        # from .modules.model import identify_k
-        # for nf in range(20, 50):
-        #     k, N_k = identify_k(10000, 44, 26)
-        #     print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
-
-        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
-
+        freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
         arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
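Note the argument order in the T2V call: get_rotary_pos_embed expects (video_length, height, width), and size here is evidently stored as (width, height), hence size[1], size[0]; the I2V path above passes h, w directly.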
@@ -252,7 +247,7 @@ class WanT2V:
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
+            offload.set_step_no_for_lora(i)
             timestep = torch.stack(timestep)
 
             # self.model.to(self.device)