Compare commits


2 Commits

Author      SHA1        Message                                         Date
Houchen Li  43ac073411  Merge 447aa08620 into 7c81b2f27d                2025-07-28 19:15:11 +08:00
Houchen Li  447aa08620  [feature] adapt for Moore Threads GPU family   2025-07-28 19:14:58 +08:00
14 changed files with 282 additions and 146 deletions

View File

@@ -4,7 +4,6 @@ import logging
 import os
 import sys
 import warnings
-from time import perf_counter
 from datetime import datetime

 warnings.filterwarnings('ignore')
@@ -20,7 +19,7 @@ try:
     import torch_musa
     from torch_musa.core.device import set_device
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
@@ -366,7 +365,6 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating WanT2V pipeline.")
-        start_time = perf_counter()
         wan_t2v = wan.WanT2V(
             config=cfg,
             checkpoint_dir=args.ckpt_dir,
@@ -377,8 +375,6 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
-        end_time = perf_counter()
-        logging.info(f"Creating WanT2V pipeline took {end_time - start_time:.2f} seconds.")

         logging.info(
             f"Generating {'image' if 't2i' in args.task else 'video'} ...")
@@ -392,6 +388,7 @@ def generate(args):
             guide_scale=args.sample_guide_scale,
             seed=args.base_seed,
             offload_model=args.offload_model)
+
     elif "i2v" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -425,7 +422,6 @@
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating WanI2V pipeline.")
-        start_time = perf_counter()
         wan_i2v = wan.WanI2V(
             config=cfg,
             checkpoint_dir=args.ckpt_dir,
@@ -436,8 +432,6 @@
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
-        end_time = perf_counter()
-        logging.info(f"Creating WanI2V pipeline took {end_time - start_time:.2f} seconds.")

         logging.info("Generating video ...")
         video = wan_i2v.generate(
@@ -586,7 +580,6 @@
                 value_range=(-1, 1))
         else:
             logging.info(f"Saving generated video to {args.save_file}")
-            start_time = perf_counter()
             cache_video(
                 tensor=video[None],
                 save_file=args.save_file,
@@ -594,8 +587,6 @@
                 nrow=1,
                 normalize=True,
                 value_range=(-1, 1))
-            end_time = perf_counter()
-            logging.info(f"Saving Video took {end_time - start_time:.2f} seconds")

     logging.info("Finished.")

View File

@@ -13,7 +13,7 @@ try:
     import torch_musa
     from torch_musa.core.memory import empty_cache
 except ModuleNotFoundError:
-    pass
+    torch_musa = None


 def shard_model(
     model,
View File

@@ -14,11 +14,24 @@ try:
     import torch_musa.core.amp as amp
     attn_type = AttnType.TORCH
 except ImportError:
-    pass
+    torch_musa = None

 from ..modules.model import sinusoidal_embedding_1d


+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
 def pad_tensor(original_tensor, target_len, pad_value=0.0):
     seq_len, s1, s2 = original_tensor.shape
     pad_size = target_len - seq_len
@@ -34,6 +47,47 @@ def pad_tensor(original_tensor, target_len, pad_value=0.0):

 @amp.autocast(enabled=False)
 def rope_apply(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
     """
     x: [B, L, N, C].
     grid_sizes: [B, 3].
@@ -141,11 +195,15 @@ def usp_dit_forward(
     # params
     dtype = self.patch_embedding.weight.dtype
     device = self.patch_embedding.weight.device
-    if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
-        self.freqs = (
-            self.freqs[0].to(dtype=dtype, device=device),
-            self.freqs[-1].to(dtype=dtype, device=device)
-        )
+    if torch_musa is not None:
+        if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
+            self.freqs = (
+                self.freqs[0].to(dtype=dtype, device=device),
+                self.freqs[-1].to(dtype=dtype, device=device)
+            )
+    else:
+        if self.freqs.dtype != dtype or self.freqs.device != device:
+            self.freqs = self.freqs.to(dtype=dtype, device=device)

     if self.model_type != 'vace' and y is not None:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -163,9 +221,11 @@ def usp_dit_forward(
     ])

     # time embeddings
-    e = self.time_embedding(
-        sinusoidal_embedding_1d(self.freq_dim, t))
-    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+    with amp.autocast(dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).float())
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+        assert e.dtype == torch.float32 and e0.dtype == torch.float32

     # context
     context_lens = None
@@ -209,7 +269,7 @@ def usp_dit_forward(

     # unpatchify
     x = self.unpatchify(x, grid_sizes)
-    return x
+    return [u.float() for u in x]


 def usp_attn_forward(self,
@@ -232,8 +292,13 @@ def usp_attn_forward(self,
         return q, k, v

     q, k, v = qkv_fn(x)
-    q = rope_apply(q, grid_sizes, freqs)
-    k = rope_apply(k, grid_sizes, freqs)
+
+    if torch_musa is not None:
+        q = rope_apply_musa(q, grid_sizes, freqs)
+        k = rope_apply_musa(k, grid_sizes, freqs)
+    else:
+        q = rope_apply(q, grid_sizes, freqs)
+        k = rope_apply(k, grid_sizes, freqs)

     # TODO: We should use unpaded q,k,v for attention.
     # k_lens = seq_lens // get_sequence_parallel_world_size()
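
The pad_freqs / rank-slicing logic above is easier to follow outside the diff. A minimal standalone sketch of the same idea, with made-up sizes and plain integers standing in for get_sequence_parallel_world_size() / get_sequence_parallel_rank():

    import torch

    def pad_freqs(original_tensor, target_len):
        # pad the rotary multipliers with ones so every rank can take an
        # equally sized slice, even when seq_len does not divide evenly
        seq_len, s1, s2 = original_tensor.shape
        padding = torch.ones(target_len - seq_len, s1, s2,
                             dtype=original_tensor.dtype,
                             device=original_tensor.device)
        return torch.cat([original_tensor, padding], dim=0)

    # hypothetical: each of 4 ranks holds s = 5 tokens of an 18-token sequence
    s, sp_size, sp_rank = 5, 4, 2
    freqs_i = torch.polar(torch.ones(18, 1, 8), torch.rand(18, 1, 8))

    freqs_i = pad_freqs(freqs_i, s * sp_size)              # 18 -> 20 positions
    freqs_rank = freqs_i[sp_rank * s:(sp_rank + 1) * s]    # positions 10..14 go to rank 2
    print(freqs_rank.shape)                                # torch.Size([5, 1, 8])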

View File

@@ -23,7 +23,7 @@ try:
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel

View File

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -24,7 +23,7 @@ try:
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
@@ -231,7 +230,6 @@ class WanI2V:
             n_prompt = self.sample_neg_prompt

         # preprocess
-        start_time = perf_counter()
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
             context = self.text_encoder([input_prompt], self.device)
@@ -243,18 +241,12 @@ class WanI2V:
             context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]
-        end_time = perf_counter()
-        logging.info(f"T5 Encoding took {end_time - start_time:.2f} seconds.")

-        start_time = perf_counter()
         self.clip.model.to(self.device)
         clip_context = self.clip.visual([img[:, None, :, :]])
         if offload_model:
             self.clip.model.cpu()
-        end_time = perf_counter()
-        logging.info(f"CLIP took {end_time - start_time:.2f} seconds.")

-        start_time = perf_counter()
         y = self.vae.encode([
             torch.concat([
                 torch.nn.functional.interpolate(
@@ -264,9 +256,6 @@ class WanI2V:
             ],
                          dim=1).to(self.device)
         ])[0]
-        end_time = perf_counter()
-        logging.info(f"VAE Encoding took {end_time - start_time:.2f} seconds.")

         y = torch.concat([msk, y])

         @contextmanager
@@ -319,7 +308,6 @@ class WanI2V:
             if offload_model:
                 empty_cache()

-            start_time = perf_counter()
             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = [latent.to(self.device)]
@@ -353,18 +341,13 @@ class WanI2V:
                 x0 = [latent.to(self.device)]

                 del latent_model_input, timestep
-            end_time = perf_counter()
-            logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")

             if offload_model:
                 self.model.cpu()
                 empty_cache()

             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.vae.decode(x0)
-                end_time = perf_counter()
-                logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")

         del noise, latent
         del sample_scheduler

View File

@@ -20,7 +20,7 @@ try:
     FLASH_ATTN_3_AVAILABLE = False
     FLASH_ATTN_2_AVAILABLE = False
 except ModuleNotFoundError:
-    pass
+    torch_musa = None


 __all__ = [
@@ -180,9 +180,8 @@ def attention(
         k = k.transpose(1, 2).to(dtype)
         v = v.transpose(1, 2).to(dtype)

-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
-            out = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)

         out = out.transpose(1, 2).contiguous()
     return out
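
Dropping the torch.backends.cuda.sdp_kernel() wrapper and calling scaled_dot_product_attention directly lets PyTorch dispatch to whatever fused kernel the current backend provides, instead of forcing the CUDA flash path. A minimal sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    # hypothetical sizes: batch 2, 8 heads, 16 tokens, head dim 64
    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)

    # no sdp_kernel() context manager: the backend (flash, memory-efficient,
    # or plain math) is chosen automatically for the current device
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    print(out.shape)  # torch.Size([2, 8, 16, 64])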

View File

@@ -18,7 +18,7 @@ try:
     import torch_musa.core.amp as amp
     from .attention import attention as flash_attention
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 __all__ = [
     'XLMRobertaCLIP',
@@ -37,7 +37,7 @@ def pos_interpolate(pos, seq_len):
         return torch.cat([
             pos[:, :n],
             F.interpolate(
-                pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
+                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
                     0, 3, 1, 2),
                 size=(tar_grid, tar_grid),
                 mode='bicubic',
@@ -52,6 +52,12 @@ class QuickGELU(nn.Module):
         return x * torch.sigmoid(1.702 * x)


+class LayerNorm(nn.LayerNorm):
+
+    def forward(self, x):
+        return super().forward(x.float()).type_as(x)
+
+
 class SelfAttention(nn.Module):

     def __init__(self,
@@ -84,7 +90,10 @@ class SelfAttention(nn.Module):

         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
+        if torch_musa is not None:
+            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
+        else:
+            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
         x = x.reshape(b, s, c)

         # output
@@ -133,10 +142,10 @@ class AttentionBlock(nn.Module):
         self.norm_eps = norm_eps

         # layers
-        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
+        self.norm1 = LayerNorm(dim, eps=norm_eps)
         self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                   proj_dropout)
-        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
+        self.norm2 = LayerNorm(dim, eps=norm_eps)
         if activation == 'swi_glu':
             self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
         else:
@@ -179,7 +188,7 @@ class AttentionPool(nn.Module):
         self.to_q = nn.Linear(dim, dim)
         self.to_kv = nn.Linear(dim, dim * 2)
         self.proj = nn.Linear(dim, dim)
-        self.norm = nn.LayerNorm(dim, eps=norm_eps)
+        self.norm = LayerNorm(dim, eps=norm_eps)
         self.mlp = nn.Sequential(
             nn.Linear(dim, int(dim * mlp_ratio)),
             QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@@ -196,7 +205,10 @@ class AttentionPool(nn.Module):
         k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

         # compute attention
-        x = flash_attention(q, k, v, version=2)
+        if torch_musa is not None:
+            x = flash_attention(q, k, v)
+        else:
+            x = flash_attention(q, k, v, version=2)
         x = x.reshape(b, 1, c)

         # output
@@ -261,13 +273,13 @@ class VisionTransformer(nn.Module):
         self.dropout = nn.Dropout(embedding_dropout)

         # transformer
-        self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
         self.transformer = nn.Sequential(*[
             AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                            activation, attn_dropout, proj_dropout, norm_eps)
             for _ in range(num_layers)
         ])
-        self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
+        self.post_norm = LayerNorm(dim, eps=norm_eps)

         # head
         if pool_type == 'token':
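
The LayerNorm subclass introduced here follows a common mixed-precision pattern: run the normalization in float32 for stability, then cast the result back to the activation dtype. A self-contained sketch (illustrative sizes only, module weights left at their default float32):

    import torch
    import torch.nn as nn

    class LayerNorm(nn.LayerNorm):

        def forward(self, x):
            # normalize in float32, hand the result back in the caller's dtype
            return super().forward(x.float()).type_as(x)

    x = torch.randn(2, 16, 32, dtype=torch.bfloat16)
    norm = LayerNorm(32, eps=1e-5)
    print(norm(x).dtype)  # torch.bfloat16: the float32 math stays internal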

View File

@@ -7,7 +7,6 @@ import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin

-from wan.utils.platform import get_device
 from wan.modules.attention import flash_attention

 try:
@@ -15,7 +14,7 @@ try:
     import torch_musa.core.amp as amp
     from wan.modules.attention import attention as flash_attention
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 __all__ = ['WanModel']
@@ -27,7 +26,7 @@ def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
     half = dim // 2
-    position = position.type(torch.bfloat16)
+    position = position.type(torch.float32)

     # calculation
     sinusoid = torch.outer(
@@ -36,6 +35,17 @@ def sinusoidal_embedding_1d(dim, position):
     return x


+@amp.autocast(enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+    assert dim % 2 == 0
+    freqs = torch.outer(
+        torch.arange(max_seq_len),
+        1.0 / torch.pow(theta,
+                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+    freqs = torch.polar(torch.ones_like(freqs), freqs)
+    return freqs
+
+
 @amp.autocast(enabled=False)
 def rope_params_real(
     max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
@@ -69,6 +79,37 @@

 @amp.autocast(enabled=False)
 def rope_apply(x, grid_sizes, freqs):
     n, c = x.size(2), x.size(3) // 2
+
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+            seq_len, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+        x_i = torch.cat([x_i, x[i, seq_len:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
+    n, c = x.size(2), x.size(3) // 2
     c0 = c - 2 * (c // 3)
     c1 = c // 3
     c2 = c // 3
@@ -134,6 +175,19 @@ class WanRMSNorm(nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


+class WanLayerNorm(nn.LayerNorm):
+
+    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, x):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+        """
+        return super().forward(x.float()).type_as(x)
+
+
 class WanSelfAttention(nn.Module):

     def __init__(self,
@@ -178,12 +232,22 @@ class WanSelfAttention(nn.Module):

         q, k, v = qkv_fn(x)

-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if torch_musa is not None:
+            x = flash_attention(
+                q=rope_apply_musa(q, grid_sizes, freqs),
+                k=rope_apply_musa(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+            )
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+            )

         # output
         x = x.flatten(2)
@@ -288,10 +352,10 @@ class WanAttentionBlock(nn.Module):
         self.eps = eps

         # layers
-        self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
+        self.norm1 = WanLayerNorm(dim, eps)
         self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                           eps)
-        self.norm3 = nn.LayerNorm(
+        self.norm3 = WanLayerNorm(
             dim, eps,
             elementwise_affine=True) if cross_attn_norm else nn.Identity()
         self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@@ -299,7 +363,7 @@ class WanAttentionBlock(nn.Module):
                                                                       (-1, -1),
                                                                       qk_norm,
                                                                       eps)
-        self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
+        self.norm2 = WanLayerNorm(dim, eps)
         self.ffn = nn.Sequential(
             nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
             nn.Linear(ffn_dim, dim))
@@ -325,19 +389,24 @@
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
-        e = (self.modulation + e).chunk(6, dim=1)
+        assert e.dtype == torch.float32
+        with amp.autocast(dtype=torch.float32):
+            e = (self.modulation + e).chunk(6, dim=1)
+        assert e[0].dtype == torch.float32

         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
             freqs)
-        x = x + y * e[2]
+        with amp.autocast(dtype=torch.float32):
+            x = x + y * e[2]

         # cross-attention & ffn function
         def cross_attn_ffn(x, context, context_lens, e):
             x = x + self.cross_attn(self.norm3(x), context, context_lens)
-            y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
-            x = x + y * e[5]
+            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+            with amp.autocast(dtype=torch.float32):
+                x = x + y * e[5]
             return x

         x = cross_attn_ffn(x, context, context_lens, e)
@@ -355,7 +424,7 @@ class Head(nn.Module):

         # layers
         out_dim = math.prod(patch_size) * out_dim
-        self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
+        self.norm = WanLayerNorm(dim, eps)
         self.head = nn.Linear(dim, out_dim)

         # modulation
@@ -367,8 +436,10 @@ class Head(nn.Module):
             x(Tensor): Shape [B, L1, C]
             e(Tensor): Shape [B, C]
         """
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
-        x = self.head(self.norm(x) * (1 + e[1]) + e[0])
+        assert e.dtype == torch.float32
+        with amp.autocast(dtype=torch.float32):
+            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x
@@ -502,23 +573,33 @@ class WanModel(ModelMixin, ConfigMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        freqs_real = torch.cat(
-            [
-                rope_params_real(1024, d - 4 * (d // 6)),
-                rope_params_real(1024, 2 * (d // 6)),
-                rope_params_real(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        )
-        freqs_imag = torch.cat(
-            [
-                rope_params_imag(1024, d - 4 * (d // 6)),
-                rope_params_imag(1024, 2 * (d // 6)),
-                rope_params_imag(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        )
-        self.freqs = (freqs_real, freqs_imag)
+        if torch_musa is not None:
+            freqs_real = torch.cat(
+                [
+                    rope_params_real(1024, d - 4 * (d // 6)),
+                    rope_params_real(1024, 2 * (d // 6)),
+                    rope_params_real(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )
+            freqs_imag = torch.cat(
+                [
+                    rope_params_imag(1024, d - 4 * (d // 6)),
+                    rope_params_imag(1024, 2 * (d // 6)),
+                    rope_params_imag(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )
+            self.freqs = (freqs_real, freqs_imag)
+        else:
+            self.freqs = torch.cat(
+                [
+                    rope_params(1024, d - 4 * (d // 6)),
+                    rope_params(1024, 2 * (d // 6)),
+                    rope_params(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )

         if model_type == 'i2v' or model_type == 'flf2v':
             self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@@ -561,11 +642,15 @@ class WanModel(ModelMixin, ConfigMixin):
         # params
         dtype = self.patch_embedding.weight.dtype
         device = self.patch_embedding.weight.device
-        if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
-            self.freqs = (
-                self.freqs[0].to(dtype=dtype, device=device),
-                self.freqs[-1].to(dtype=dtype, device=device)
-            )
+        if torch_musa is not None:
+            if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
+                self.freqs = (
+                    self.freqs[0].to(dtype=dtype, device=device),
+                    self.freqs[-1].to(dtype=dtype, device=device)
+                )
+        else:
+            if self.freqs.dtype != dtype or self.freqs.device != device:
+                self.freqs = self.freqs.to(dtype=dtype, device=device)

         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -583,9 +668,11 @@ class WanModel(ModelMixin, ConfigMixin):
         ])

         # time embeddings
-        e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t))
-        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+        with amp.autocast(dtype=torch.float32):
+            e = self.time_embedding(
+                sinusoidal_embedding_1d(self.freq_dim, t).float())
+            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+            assert e.dtype == torch.float32 and e0.dtype == torch.float32

         # context
         context_lens = None
@@ -617,7 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):

         # unpatchify
         x = self.unpatchify(x, grid_sizes)
-        return x
+        return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
         r"""

View File

@@ -6,15 +6,15 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.cuda import current_device

 try:
     import torch_musa
     from torch_musa.core.device import current_device
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

-from wan.modules.tokenizers import HuggingfaceTokenizer
-from wan.utils.platform import get_device
+from .tokenizers import HuggingfaceTokenizer

 __all__ = [
     'T5Model',
@@ -66,8 +66,10 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
+        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                             self.eps)
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            x = x.type_as(self.weight)
         return self.weight * x
@@ -115,7 +117,7 @@ class T5Attention(nn.Module):

         # compute attention (T5 does not use scaling)
         attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
-        attn = F.softmax(attn, dim=-1)
+        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
         x = torch.einsum('bnij,bjnc->binc', attn, v)

         # output
@@ -260,7 +262,7 @@ class T5RelativeEmbedding(nn.Module):

         # embeddings for small and large positions
         max_exact = num_buckets // 2
-        rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
+        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                      math.log(self.max_dist / max_exact) *
                                      (num_buckets - max_exact)).long()
         rel_pos_large = torch.min(
@@ -480,7 +482,7 @@ class T5EncoderModel:
         self,
         text_len,
         dtype=torch.bfloat16,
-        device=get_device(),
+        device=current_device(),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,

View File

@@ -8,7 +8,7 @@ try:
     import torch_musa
     import torch_musa.core.amp as amp
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d

View File

@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import logging
-from math import sqrt

 import torch
+import torch.cuda.amp as amp
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Upsample
@@ -10,8 +10,9 @@ from einops import rearrange

 try:
     import torch_musa
+    import torch_musa.core.amp as amp
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from wan.utils.platform import get_device
@@ -52,17 +53,23 @@ class RMS_norm(nn.Module):
         shape = (dim, *broadcastable_dims) if channel_first else (dim,)

         self.channel_first = channel_first
-        self.scale = sqrt(dim)
+        self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(shape))
         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

     def forward(self, x):
-        return (
-            F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
-            * self.scale
-            * self.gamma
-            + self.bias
-        )
+        return F.normalize(
+            x, dim=(1 if self.channel_first else
+                    -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+    def forward(self, x):
+        """
+        Fix bfloat16 support for nearest neighbor interpolation.
+        """
+        return super().forward(x.float()).type_as(x)


 class Resample(nn.Module):
@@ -255,10 +262,6 @@ class AttentionBlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
-            dropout_p=0.0,
-            is_causal=False,
-            scale=None,
         )

         x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
@@ -627,7 +630,7 @@ class WanVAE:
     def __init__(self,
                  z_dim=16,
                  vae_pth='cache/vae_step_411000.pth',
-                 dtype=torch.bfloat16,
+                 dtype=torch.float,
                  device=get_device()):
         self.dtype = dtype
         self.device = device
@@ -654,12 +657,16 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        return [
-            self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
-        ]
+        with amp.autocast(dtype=self.dtype):
+            return [
+                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
+                for u in videos
+            ]

     def decode(self, zs):
-        return [
-            self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
-            for u in zs
-        ]
+        with amp.autocast(dtype=self.dtype):
+            return [
+                self.model.decode(u.unsqueeze(0),
+                                  self.scale).float().clamp_(-1, 1).squeeze(0)
+                for u in zs
+            ]

View File

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -22,7 +21,7 @@ try:
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from .distributed.fsdp import shard_model
 from .modules.model import WanModel
@@ -181,7 +180,6 @@ class WanT2V:
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)

-        start_time = perf_counter()
         if not self.t5_cpu:
             self.text_encoder.model.to(self.device)
             context = self.text_encoder([input_prompt], self.device)
@@ -193,8 +191,6 @@ class WanT2V:
             context_null = self.text_encoder([n_prompt], torch.device('cpu'))
             context = [t.to(self.device) for t in context]
             context_null = [t.to(self.device) for t in context_null]
-        end_time = perf_counter()
-        logging.info(f"T5 Encoding Context took {end_time - start_time:.2f} seconds.")

         noise = [
             torch.randn(
@@ -243,14 +239,13 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}

-            start_time = perf_counter()
-            self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]

                 timestep = torch.stack(timestep)

+                self.model.to(self.device)
                 noise_pred_cond = self.model(
                     latent_model_input, t=timestep, **arg_c)[0]
                 noise_pred_uncond = self.model(
@@ -266,18 +261,13 @@ class WanT2V:
                     return_dict=False,
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]
-            end_time = perf_counter()
-            logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")

             x0 = latents
             if offload_model:
                 self.model.cpu()
                 empty_cache()
             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.vae.decode(x0)
-                end_time = perf_counter()
-                logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")

         del noise, latents
         del sample_scheduler

View File

@@ -5,7 +5,7 @@ import torch
 try:
     import torch_musa
 except ModuleNotFoundError:
-    pass
+    torch_musa = None


 def _is_musa():
@@ -31,4 +31,4 @@ def get_torch_distributed_backend() -> str:
     elif _is_musa():
         return "mccl"
     else:
-        raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available")
+        raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
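
Every file touched by this commit swaps the bare "except ModuleNotFoundError: pass" for "torch_musa = None", so later code can branch on the name instead of guessing whether the import succeeded. A minimal sketch of that pattern; the torch.musa.is_available() call and the "musa" device string are assumptions based on torch_musa's usual usage, not something shown in this diff:

    import torch

    try:
        import torch_musa  # Moore Threads GPU backend; absent on CUDA-only machines
    except ModuleNotFoundError:
        torch_musa = None

    def pick_device():
        # branch on the sentinel rather than wrapping every call in try/except
        if torch_musa is not None and torch.musa.is_available():
            return torch.device("musa")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    print(pick_device())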

View File

@@ -27,7 +27,7 @@ try:
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 from .modules.vace_model import VaceWanModel
 from .text2video import (