# Wan2.1/wan/modules/rope.py
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

def pad_tensor(
    original_tensor: torch.Tensor, target_len: int, pad_value: float = 0.0
) -> torch.Tensor:
    """Pad `original_tensor` along dim 0 with `pad_value` until it has `target_len` rows."""
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.full(
        (pad_size, s1, s2),
        pad_value,
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def rope_apply_pytorch(
    x: torch.Tensor,
    grid_sizes: torch.Tensor,
    freqs: Tuple[torch.Tensor, torch.Tensor],
    sp_size: Optional[int] = None,
    sp_rank: Optional[int] = None,
) -> torch.Tensor:
    """
    x: [B, L, N, C] with real/imag channel pairs interleaved along the last dim.
    grid_sizes: [B, 3] holding (frames, height, width) per sample.
    freqs: (real, imag) frequency tables, each [M, C // 2].
    sp_size / sp_rank: sequence-parallel world size and rank; None disables SP.
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    c0 = c - 2 * (c // 3)
    c1 = c // 3
    c2 = c // 3

    # split freqs
    freqs_real = freqs[0].split([c0, c1, c2], dim=1)
    freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = x[i, :seq_len].reshape(s, n, -1, 2)
        x_real = x_i[..., 0]
        x_imag = x_i[..., 1]
        # per-sample frequency tables (kept separate from the split tuples above
        # so the next sample in the batch still sees the original splits)
        freqs_i_real = torch.cat(
            [
                freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
                freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
                freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        freqs_i_imag = torch.cat(
            [
                freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
                freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
                freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        if sp_rank is None:
            freqs_real_rank = freqs_i_real
            freqs_imag_rank = freqs_i_imag
        else:
            # pad with the identity rotation (real=1, imag=0), then take this rank's chunk
            freqs_i_real = pad_tensor(freqs_i_real, s * sp_size, 1.0)
            freqs_i_imag = pad_tensor(freqs_i_imag, s * sp_size, 0.0)
            freqs_real_rank = freqs_i_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
            freqs_imag_rank = freqs_i_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
        out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
        out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)

        # append to collection
        output.append(x_out)
    return torch.stack(output)


@triton.jit
def rope_kernel(
    x_ptr,  # [B, S, N, 2C]
    grid_sizes_ptr,  # [B, 3]
    freqs_real_ptr,  # [M, C]
    freqs_imag_ptr,  # [M, C]
    output_ptr,  # [B, S, N, 2C]
    sp_size,  # SP world size
    sp_rank,  # SP rank
    B,
    S,
    N: tl.constexpr,
    C: tl.constexpr,
    M: tl.constexpr,
    CfM: tl.constexpr,
    ChM: tl.constexpr,
    CwM: tl.constexpr,
    SEQ_BLOCK: tl.constexpr,
    HEADS_BLOCK: tl.constexpr,
):
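    # One program handles a (batch, SEQ_BLOCK, HEADS_BLOCK) tile; the launch grid
    # is (B, cdiv(S, SEQ_BLOCK), cdiv(N, HEADS_BLOCK)). The frequency tables are
    # [M, C] with columns [0, Cf) for frames, [Cf, Cf + Ch) for rows and
    # [Cf + Ch, C) for columns of the spatial grid.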
    Cf = C - 2 * (C // 3)
    Ch = C // 3
    Cw = C // 3
    batch_idx = tl.program_id(0)
    seqlen_group_idx = tl.program_id(1)
    head_group_idx = tl.program_id(2)

    base = batch_idx * 3
    F = tl.load(grid_sizes_ptr + base + 0)
    H = tl.load(grid_sizes_ptr + base + 1)
    W = tl.load(grid_sizes_ptr + base + 2)
    seq_len = F * H * W

    global_offset = sp_rank * S + seqlen_group_idx * SEQ_BLOCK
    seq_indices = global_offset + tl.arange(0, SEQ_BLOCK)
    limit = tl.minimum(seq_len, S * sp_size)
    seq_mask = seq_indices < limit
    seq_indices = tl.where(seq_mask, seq_indices, 0)

    HW = H * W
    f_idx = seq_indices // HW
    rem = seq_indices - f_idx * HW
    h_idx = rem // W
    w_idx = rem - h_idx * W
    freq_offset_cf = tl.arange(0, CfM)  # segment-1 column offsets [0, Cf)
    freq_offset_ch = Cf + tl.arange(0, ChM)  # segment-2 column offsets [Cf, Cf + Ch)
    freq_offset_cw = Cf + Ch + tl.arange(0, CwM)  # segment-3 column offsets [Cf + Ch, C)

    # Gather the frequency for each sequence position; broadcasting picks a
    # different table row per position.
    # Address into the frequency table = row_idx * C + col_offset
    freq_addr_cf = f_idx[:, None] * C + freq_offset_cf[None, :]
    freq_addr_ch = h_idx[:, None] * C + freq_offset_ch[None, :]
    freq_addr_cw = w_idx[:, None] * C + freq_offset_cw[None, :]
    # Out-of-range positions fall back to the identity rotation (real=1, imag=0),
    # matching the padding used in the PyTorch path.
    freqs_real_cf = tl.load(
        freqs_real_ptr + freq_addr_cf,
        mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
        other=1.0,
    ).to(tl.float32)
    freqs_imag_cf = tl.load(
        freqs_imag_ptr + freq_addr_cf,
        mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
        other=0.0,
    ).to(tl.float32)
    freqs_real_ch = tl.load(
        freqs_real_ptr + freq_addr_ch,
        mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
        other=1.0,
    ).to(tl.float32)
    freqs_imag_ch = tl.load(
        freqs_imag_ptr + freq_addr_ch,
        mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
        other=0.0,
    ).to(tl.float32)
    freqs_real_cw = tl.load(
        freqs_real_ptr + freq_addr_cw,
        mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
        other=1.0,
    ).to(tl.float32)
    freqs_imag_cw = tl.load(
        freqs_imag_ptr + freq_addr_cw,
        mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
        other=0.0,
    ).to(tl.float32)
    # Add a singleton head dimension so the frequencies broadcast against x.
    freqs_real_cf = freqs_real_cf[:, None, :]  # [SEQ_BLOCK, 1, CfM]
    freqs_imag_cf = freqs_imag_cf[:, None, :]
    freqs_real_ch = freqs_real_ch[:, None, :]
    freqs_imag_ch = freqs_imag_ch[:, None, :]
    freqs_real_cw = freqs_real_cw[:, None, :]
    freqs_imag_cw = freqs_imag_cw[:, None, :]

    # Load the real/imag parts of this x tile (shape: [SEQ_BLOCK, HEADS_BLOCK, C]).
    seq_offset = seqlen_group_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
    head_offset = head_group_idx * HEADS_BLOCK + tl.arange(0, HEADS_BLOCK)
    # Base offsets into x_ptr for every (sequence, head) element of the tile.
    base_offset = batch_idx * S * N * 2 * C
    seq_head_offset = (
        base_offset
        + seq_offset[:, None, None] * (N * 2 * C)
        + head_offset[None, :, None] * (2 * C)
    )
    x_mask = (seq_offset < S)[:, None, None] & (head_offset < N)[None, :, None]

    # Load each channel segment of x; elements beyond the valid range are masked to 0.
    # Segment 1: complex channels [0, Cf)
    chan_cf = tl.arange(0, CfM * 2)
    mask_2cf_chan = chan_cf < Cf * 2
    x_cf = tl.load(
        x_ptr + seq_head_offset + chan_cf[None, None, :],
        mask=(x_mask & mask_2cf_chan[None, None, :]),
        other=0.0,
    ).to(tl.float32)
    x_cf = x_cf.reshape(
        SEQ_BLOCK, HEADS_BLOCK, CfM, 2
    )  # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
    x_real_cf, x_imag_cf = x_cf.split()
    # Apply the RoPE rotation to segment 1.
    out_real_cf = x_real_cf * freqs_real_cf - x_imag_cf * freqs_imag_cf
    out_imag_cf = x_real_cf * freqs_imag_cf + x_imag_cf * freqs_real_cf
    out_cf = tl.interleave(out_real_cf, out_imag_cf)  # [SEQ_BLOCK, HEADS_BLOCK, CfM * 2]
    tl.store(
        output_ptr + seq_head_offset + chan_cf[None, None, :],
        out_cf,
        mask=(x_mask & mask_2cf_chan[None, None, :]),
    )

    # Segment 2: complex channels [Cf, Cf + Ch)
    chan_ch = tl.arange(0, ChM * 2) + Cf * 2
    mask_2ch_chan = chan_ch < 2 * (Cf + Ch)
    x_ch = tl.load(
        x_ptr + seq_head_offset + chan_ch[None, None, :],
        mask=(x_mask & mask_2ch_chan[None, None, :]),
        other=0.0,
    ).to(tl.float32)
    x_ch = x_ch.reshape(SEQ_BLOCK, HEADS_BLOCK, ChM, 2)
    x_real_ch, x_imag_ch = x_ch.split()
    out_real_ch = x_real_ch * freqs_real_ch - x_imag_ch * freqs_imag_ch
    out_imag_ch = x_real_ch * freqs_imag_ch + x_imag_ch * freqs_real_ch
    out_ch = tl.interleave(out_real_ch, out_imag_ch)  # [SEQ_BLOCK, HEADS_BLOCK, ChM * 2]
    tl.store(
        output_ptr + seq_head_offset + chan_ch[None, None, :],
        out_ch,
        mask=(x_mask & mask_2ch_chan[None, None, :]),
    )
    # Segment 3: complex channels [Cf + Ch, C)
    chan_cw = tl.arange(0, CwM * 2) + (Cf + Ch) * 2
    mask_2cw_chan = chan_cw < 2 * C
    x_cw = tl.load(
        x_ptr + seq_head_offset + chan_cw[None, None, :],
        mask=(x_mask & mask_2cw_chan[None, None, :]),
        other=0.0,
    ).to(tl.float32)
    x_cw = x_cw.reshape(SEQ_BLOCK, HEADS_BLOCK, CwM, 2)
    x_real_cw, x_imag_cw = x_cw.split()
    out_real_cw = x_real_cw * freqs_real_cw - x_imag_cw * freqs_imag_cw
    out_imag_cw = x_real_cw * freqs_imag_cw + x_imag_cw * freqs_real_cw
    out_cw = tl.interleave(out_real_cw, out_imag_cw)
    tl.store(
        output_ptr + seq_head_offset + chan_cw[None, None, :],
        out_cw,
        mask=(x_mask & mask_2cw_chan[None, None, :]),
    )


@torch._dynamo.disable
def rope_apply_triton(
    x: torch.Tensor,
    grid_sizes: torch.Tensor,
    freqs: Tuple[torch.Tensor, torch.Tensor],
    sp_size: Optional[int] = None,
    sp_rank: Optional[int] = None,
) -> torch.Tensor:
    """
    Example shapes:
        x: [1, 9450, 40, 128]
        grid_sizes: [[21, 45, 80]]
        freqs_real: [1024, 64]
        freqs_imag: [1024, 64]
    """
    B, S, N, C = x.shape
    C = C // 2
    Cf = C - 2 * (C // 3)  # frequency channels for the first (frame) axis
    Ch = C // 3  # frequency channels for the second (height) axis
    Cw = C // 3  # frequency channels for the third (width) axis
    M = freqs[0].shape[0]
    SEQ_BLOCK = 64  # sequence positions handled per program
    HEADS_BLOCK = 8  # attention heads handled per program
    if sp_rank is None:
        sp_size = 1
        sp_rank = 0
    grid_sizes = grid_sizes.to(device=x.device)
    output = torch.empty_like(x)
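    # tl.arange needs power-of-two extents, so the per-segment widths are padded
    # to CfM/ChM/CwM below and the kernel masks off the padded columns.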
    rope_kernel[(B, triton.cdiv(S, SEQ_BLOCK), triton.cdiv(N, HEADS_BLOCK))](
        x,
        grid_sizes,
        freqs[0],
        freqs[-1],
        output,
        sp_size,
        sp_rank,
        B,
        S,
        N=N,
        C=C,
        M=M,
        CfM=triton.next_power_of_2(Cf),
        ChM=triton.next_power_of_2(Ch),
        CwM=triton.next_power_of_2(Cw),
        SEQ_BLOCK=SEQ_BLOCK,
        HEADS_BLOCK=HEADS_BLOCK,
        num_warps=32,
        num_stages=3,
    )
    return output.float()
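

# A minimal sanity-check sketch, assuming a CUDA device is available; the shapes,
# rotary base, and (cos, sin) frequency tables below are illustrative toy values,
# not the ones used by the model.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda"
    B, N, C2 = 1, 4, 64  # batch, heads, per-head channels (2 * C)
    f, h, w = 4, 6, 8  # frames, height, width of the toy grid
    S = f * h * w  # sequence length
    C = C2 // 2

    # Toy frequency tables: one (cos, sin) angle per position index and channel pair.
    M = 1024
    pos = torch.arange(M, device=device, dtype=torch.float32)[:, None]
    inv_freq = 1.0 / (10000.0 ** (torch.arange(C, device=device, dtype=torch.float32) / C))
    angles = pos * inv_freq[None, :]  # [M, C]
    freqs = (angles.cos(), angles.sin())

    x = torch.randn(B, S, N, C2, device=device)
    grid_sizes = torch.tensor([[f, h, w]], dtype=torch.int64)

    # Compare the Triton kernel against the PyTorch reference.
    ref = rope_apply_pytorch(x, grid_sizes, freqs)
    out = rope_apply_triton(x, grid_sizes, freqs)
    print("max abs diff:", (ref - out).abs().max().item())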