Houchen Li 2025-07-28 17:23:28 +08:00 committed by GitHub
commit 618d94c564
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 399 additions and 181 deletions

View File

@ -4,6 +4,7 @@ import logging
import os
import sys
import warnings
from time import perf_counter
from datetime import datetime
warnings.filterwarnings('ignore')
@ -12,12 +13,20 @@ import random
import torch
import torch.distributed as dist
from torch.cuda import set_device
from PIL import Image
try:
import torch_musa
from torch_musa.core.device import set_device
except ModuleNotFoundError:
pass
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
from wan.utils.platform import get_torch_distributed_backend
EXAMPLE_PROMPT = {
@ -275,9 +284,9 @@ def generate(args):
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
set_device(local_rank)
dist.init_process_group(
backend="nccl",
backend=get_torch_distributed_backend(),
init_method="env://",
rank=rank,
world_size=world_size)
@ -357,6 +366,7 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
start_time = perf_counter()
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -367,6 +377,8 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
end_time = perf_counter()
logging.info(f"Creating WanT2V pipeline took {end_time - start_time:.2f} seconds.")
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
@ -380,7 +392,6 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "i2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -414,6 +425,7 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
start_time = perf_counter()
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -424,6 +436,8 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
end_time = perf_counter()
logging.info(f"Creating WanI2V pipeline took {end_time - start_time:.2f} seconds.")
logging.info("Generating video ...")
video = wan_i2v.generate(
@ -572,6 +586,7 @@ def generate(args):
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
start_time = perf_counter()
cache_video(
tensor=video[None],
save_file=args.save_file,
@ -579,6 +594,8 @@ def generate(args):
nrow=1,
normalize=True,
value_range=(-1, 1))
end_time = perf_counter()
logging.info(f"Saving Video took {end_time - start_time:.2f} seconds")
logging.info("Finished.")

View File

@ -3,11 +3,17 @@ import gc
from functools import partial
import torch
from torch.cuda import empty_cache
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
try:
import torch_musa
from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
pass
def shard_model(
model,
@ -40,4 +46,4 @@ def free_model(model):
_free_storage(m._handle.flat_param.data)
del model
gc.collect()
torch.cuda.empty_cache()
empty_cache()

View File

@ -6,20 +6,28 @@ from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
attn_type: AttnType = AttnType.FA
try:
import torch_musa
import torch_musa.core.amp as amp
attn_type = AttnType.TORCH
except ImportError:
pass
from ..modules.model import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len):
def pad_tensor(original_tensor, target_len, pad_value=0.0):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
padding_tensor = torch.full(
(pad_size, s1, s2),
pad_value,
dtype=original_tensor.dtype,
device=original_tensor.device)
device=original_tensor.device,
)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@ -32,8 +40,13 @@ def rope_apply(x, grid_sizes, freqs):
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
c0 = c - 2 * (c // 3)
c1 = c // 3
c2 = c // 3
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
freqs_real = freqs[0].split([c0, c1, c2], dim=1)
freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
# loop over samples
output = []
@ -41,28 +54,45 @@ def rope_apply(x, grid_sizes, freqs):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
x_i = x[i, :seq_len].reshape(s, n, -1, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_real = torch.cat(
[
freqs_real[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs_real[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs_real[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1).reshape(seq_len, 1, -1)
dim=-1,
).reshape(seq_len, 1, -1)
freqs_imag = torch.cat(
[
freqs_imag[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs_imag[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs_imag[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
# append to collection
output.append(x_i)
return torch.stack(output).float()
output.append(x_out)
return torch.stack(output)
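The rewrite above drops `torch.view_as_complex`/`torch.polar` and applies the rotation with explicit real and imaginary parts, which is the same arithmetic because (a + bi)(c + di) = (ac - bd) + (ad + bc)i. A small self-contained check of that equivalence (illustration only, not part of the diff):

```python
import torch

torch.manual_seed(0)
s, n, c = 8, 2, 16                                   # toy sizes: positions, heads, head_dim // 2
x = torch.randn(s, n, c, 2, dtype=torch.float64)     # (real, imag) pairs, as in the reshape above
angles = torch.randn(s, 1, c, dtype=torch.float64)   # one rotary angle per position/channel

# Old path: pack into complex tensors and rotate with a complex multiply.
ref = torch.view_as_real(
    torch.view_as_complex(x) * torch.polar(torch.ones_like(angles), angles)
).flatten(2)

# New path (as in this diff): rotate real and imaginary components explicitly.
x_real, x_imag = x[..., 0], x[..., 1]
f_real, f_imag = angles.cos(), angles.sin()
out = torch.stack(
    [x_real * f_real - x_imag * f_imag,
     x_real * f_imag + x_imag * f_real],
    dim=-1,
).flatten(2)

assert torch.allclose(out, ref)
```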
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
@ -109,9 +139,13 @@ def usp_dit_forward(
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)
if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -129,11 +163,9 @@ def usp_dit_forward(
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
@ -177,7 +209,7 @@ def usp_dit_forward(
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return x
def usp_attn_forward(self,
@ -210,7 +242,7 @@ def usp_attn_forward(self,
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
x = xFuserLongContextAttention()(
x = xFuserLongContextAttention(attn_type=attn_type)(
None,
query=half(q),
key=half(k),

View File

@ -12,10 +12,19 @@ from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@ -27,6 +36,7 @@ from .utils.fm_solvers import (
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanFLF2V:
@ -66,7 +76,7 @@ class WanFLF2V:
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.use_usp = use_usp
@ -323,7 +333,7 @@ class WanFLF2V:
}
if offload_model:
torch.cuda.empty_cache()
empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
@ -336,12 +346,12 @@ class WanFLF2V:
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
@ -361,7 +371,7 @@ class WanFLF2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
@ -370,7 +380,7 @@ class WanFLF2V:
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()

View File

@ -6,16 +6,26 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@ -27,6 +37,7 @@ from .utils.fm_solvers import (
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanI2V:
@ -66,7 +77,7 @@ class WanI2V:
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.use_usp = use_usp
@ -220,6 +231,7 @@ class WanI2V:
n_prompt = self.sample_neg_prompt
# preprocess
start_time = perf_counter()
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
@ -231,12 +243,18 @@ class WanI2V:
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
end_time = perf_counter()
logging.info(f"T5 Encoding took {end_time - start_time:.2f} seconds.")
start_time = perf_counter()
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
end_time = perf_counter()
logging.info(f"CLIP took {end_time - start_time:.2f} seconds.")
start_time = perf_counter()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
@ -246,6 +264,9 @@ class WanI2V:
],
dim=1).to(self.device)
])[0]
end_time = perf_counter()
logging.info(f"VAE Encoding took {end_time - start_time:.2f} seconds.")
y = torch.concat([msk, y])
@contextmanager
@ -296,8 +317,9 @@ class WanI2V:
}
if offload_model:
torch.cuda.empty_cache()
empty_cache()
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
@ -309,12 +331,12 @@ class WanI2V:
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
@ -331,19 +353,24 @@ class WanI2V:
x0 = [latent.to(self.device)]
del latent_model_input, timestep
end_time = perf_counter()
logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()

View File

@ -1,4 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings
import torch
try:
@ -13,7 +15,13 @@ try:
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
try:
import torch_musa
FLASH_ATTN_3_AVAILABLE = False
FLASH_ATTN_2_AVAILABLE = False
except ModuleNotFoundError:
pass
__all__ = [
'flash_attention',
@ -51,7 +59,7 @@ def flash_attention(
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@ -172,8 +180,9 @@ def attention(
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
out = out.transpose(1, 2).contiguous()
return out

View File

@ -6,12 +6,20 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision.transforms as T
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
try:
import torch_musa
import torch_musa.core.amp as amp
from .attention import attention as flash_attention
except ModuleNotFoundError:
pass
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len):
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
@ -44,12 +52,6 @@ class QuickGELU(nn.Module):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
@ -82,7 +84,7 @@ class SelfAttention(nn.Module):
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
x = x.reshape(b, s, c)
# output
@ -131,10 +133,10 @@ class AttentionBlock(nn.Module):
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
@ -177,7 +179,7 @@ class AttentionPool(nn.Module):
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.norm = nn.LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@ -259,13 +261,13 @@ class VisionTransformer(nn.Module):
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
@ -537,6 +539,6 @@ class CLIPModel:
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out

View File

@ -7,7 +7,15 @@ import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
from wan.utils.platform import get_device
from wan.modules.attention import flash_attention
try:
import torch_musa
import torch_musa.core.amp as amp
from wan.modules.attention import attention as flash_attention
except ModuleNotFoundError:
pass
__all__ = ['WanModel']
@ -19,7 +27,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
position = position.type(torch.bfloat16)
# calculation
sinusoid = torch.outer(
@ -29,22 +37,45 @@ def sinusoidal_embedding_1d(dim, position):
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
def rope_params_real(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
freqs_real = torch.outer(
torch.arange(max_seq_len, dtype=dtype, device=device),
1.0
/ torch.pow(
theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
),
)
return torch.cos(freqs_real)
@amp.autocast(enabled=False)
def rope_params_imag(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0
freqs_imag = torch.outer(
torch.arange(max_seq_len, dtype=dtype, device=device),
1.0
/ torch.pow(
theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
),
)
return torch.sin(freqs_imag)
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
c0 = c - 2 * (c // 3)
c1 = c // 3
c2 = c // 3
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
freqs_real = freqs[0].split([c0, c1, c2], dim=1)
freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
# loop over samples
output = []
@ -52,22 +83,36 @@ def rope_apply(x, grid_sizes, freqs):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
seq_len, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_real = torch.cat(
[
freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1).reshape(seq_len, 1, -1)
dim=-1,
).reshape(seq_len, 1, c)
freqs_imag = torch.cat(
[
freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1,
).reshape(seq_len, 1, c)
out_real = x_real * freqs_real - x_imag * freqs_imag
out_imag = x_real * freqs_imag + x_imag * freqs_real
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
# append to collection
output.append(x_i)
return torch.stack(output).float()
output.append(x_out)
return torch.stack(output)
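Likewise, `rope_params_real`/`rope_params_imag` above store the cos/sin tables that previously lived in the real and imaginary parts of the `torch.polar` output, so the frequencies can stay in a real dtype (float32 by default, cast to the model dtype in `forward`) instead of complex128. A quick sketch of that correspondence, with `angle_table` as a hypothetical helper reproducing the outer product both functions compute:

```python
import torch


def angle_table(max_seq_len: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Positions times inverse frequencies, as in rope_params_real / rope_params_imag.
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float64).div(dim))
    return torch.outer(torch.arange(max_seq_len, dtype=torch.float64), inv_freq)


angles = angle_table(max_seq_len=1024, dim=64)

# Former representation: complex exponentials exp(i * angle).
freqs_complex = torch.polar(torch.ones_like(angles), angles)

# Representation after this diff: the same information as two real tensors.
freqs_real, freqs_imag = angles.cos(), angles.sin()

assert torch.allclose(freqs_complex.real, freqs_real)
assert torch.allclose(freqs_complex.imag, freqs_imag)
```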
class WanRMSNorm(nn.Module):
@ -89,19 +134,6 @@ class WanRMSNorm(nn.Module):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self,
@ -256,10 +288,10 @@ class WanAttentionBlock(nn.Module):
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
self.norm3 = nn.LayerNorm(
dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@ -267,7 +299,7 @@ class WanAttentionBlock(nn.Module):
(-1, -1),
qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
@ -293,23 +325,18 @@ class WanAttentionBlock(nn.Module):
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@ -328,7 +355,7 @@ class Head(nn.Module):
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim)
# modulation
@ -340,10 +367,8 @@ class Head(nn.Module):
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
x = self.head(self.norm(x) * (1 + e[1]) + e[0])
return x
@ -477,12 +502,23 @@ class WanModel(ModelMixin, ConfigMixin):
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
freqs_real = torch.cat(
[
rope_params_real(1024, d - 4 * (d // 6)),
rope_params_real(1024, 2 * (d // 6)),
rope_params_real(1024, 2 * (d // 6)),
],
dim=1)
dim=1,
)
freqs_imag = torch.cat(
[
rope_params_imag(1024, d - 4 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
],
dim=1,
)
self.freqs = (freqs_real, freqs_imag)
if model_type == 'i2v' or model_type == 'flf2v':
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@ -523,9 +559,13 @@ class WanModel(ModelMixin, ConfigMixin):
if self.model_type == 'i2v' or self.model_type == 'flf2v':
assert clip_fea is not None and y is not None
# params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -543,11 +583,9 @@ class WanModel(ModelMixin, ConfigMixin):
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
@ -579,7 +617,7 @@ class WanModel(ModelMixin, ConfigMixin):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return x
def unpatchify(self, x, grid_sizes):
r"""

View File

@ -7,7 +7,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
try:
import torch_musa
from torch_musa.core.device import current_device
except ModuleNotFoundError:
pass
from wan.modules.tokenizers import HuggingfaceTokenizer
from wan.utils.platform import get_device
__all__ = [
'T5Model',
@ -59,10 +66,8 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
@ -110,7 +115,7 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
attn = F.softmax(attn, dim=-1)
x = torch.einsum('bnij,bjnc->binc', attn, v)
# output
@ -255,7 +260,7 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
@ -475,7 +480,7 @@ class T5EncoderModel:
self,
text_len,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
device=get_device(),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,

View File

@ -4,6 +4,12 @@ import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
try:
import torch_musa
import torch_musa.core.amp as amp
except ModuleNotFoundError:
pass
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d

View File

@ -1,12 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
from math import sqrt
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample
from einops import rearrange
try:
import torch_musa
except ModuleNotFoundError:
pass
from wan.utils.platform import get_device
__all__ = [
'WanVAE',
]
@ -44,23 +52,17 @@ class RMS_norm(nn.Module):
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.scale = sqrt(dim)
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
return (
F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
* self.scale
* self.gamma
+ self.bias
)
class Resample(nn.Module):
@ -253,6 +255,10 @@ class AttentionBlock(nn.Module):
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
@ -621,8 +627,8 @@ class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=torch.float,
device="cuda"):
dtype=torch.bfloat16,
device=get_device()):
self.dtype = dtype
self.device = device
@ -648,16 +654,12 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
]
def decode(self, zs):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
for u in zs
]

View File

@ -6,14 +6,24 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -24,7 +34,7 @@ from .utils.fm_solvers import (
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanT2V:
@ -60,7 +70,7 @@ class WanT2V:
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
@ -171,6 +181,7 @@ class WanT2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
start_time = perf_counter()
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
@ -182,6 +193,8 @@ class WanT2V:
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
end_time = perf_counter()
logging.info(f"T5 Encoding Context took {end_time - start_time:.2f} seconds.")
noise = [
torch.randn(
@ -230,13 +243,14 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
@ -252,19 +266,24 @@ class WanT2V:
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
end_time = perf_counter()
logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()

View File

@ -5,9 +5,10 @@ from .fm_solvers import (
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
from .platform import get_device, get_torch_distributed_backend
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor'
'VaceVideoProcessor', 'get_device', 'get_torch_distributed_backend'
]

wan/utils/platform.py (new file)
View File

@ -0,0 +1,34 @@
from typing import Optional

import torch

try:
    import torch_musa
except ModuleNotFoundError:
    pass


def _is_musa():
    try:
        # torch.musa only exists when the torch_musa import above succeeded.
        return torch.musa.is_available()
    except (AttributeError, ModuleNotFoundError):
        return False


def get_device(local_rank: Optional[int] = None) -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device() if local_rank is None else local_rank)
    elif _is_musa():
        return torch.device("musa", torch.musa.current_device() if local_rank is None else local_rank)
    else:
        return torch.device("cpu")


def get_torch_distributed_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"
    elif _is_musa():
        return "mccl"
    else:
        raise NotImplementedError(
            "No supported accelerator (NVIDIA/AMD GPU via CUDA/ROCm, or MTT GPU via MUSA) available.")
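The new `wan/utils/platform.py` gives the rest of the package one place to resolve the device and the collective backend instead of hard-coding `cuda`/`nccl`. A minimal usage sketch mirroring the `generate.py` change earlier in this diff (the `torchrun`-style environment variables are an assumption of the sketch, not part of the commit):

```python
import os

import torch.distributed as dist

from wan.utils.platform import get_device, get_torch_distributed_backend

# Typical torchrun-provided environment; any launcher that sets these works.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

device = get_device(local_rank)  # cuda:N, musa:N, or cpu
if world_size > 1:
    dist.init_process_group(
        backend=get_torch_distributed_backend(),  # "nccl" on CUDA, "mccl" on MUSA
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
```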

View File

@ -13,6 +13,7 @@ from functools import partial
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
@ -20,6 +21,14 @@ import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .modules.vace_model import VaceWanModel
from .text2video import (
FlowDPMSolverMultistepScheduler,
@ -32,6 +41,7 @@ from .text2video import (
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor
from .utils.platform import get_device, get_torch_distributed_backend
class WanVace(WanT2V):
@ -68,7 +78,7 @@ class WanVace(WanT2V):
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
@ -460,7 +470,7 @@ class WanVace(WanT2V):
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
videos = self.decode_latent(x0, input_ref_images)
@ -468,7 +478,7 @@ class WanVace(WanT2V):
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
@ -568,7 +578,7 @@ class WanVaceMP(WanVace):
torch.cuda.set_device(gpu)
dist.init_process_group(
backend='nccl',
backend=get_torch_distributed_backend(),
init_method='env://',
rank=rank,
world_size=world_size)
@ -633,7 +643,7 @@ class WanVaceMP(WanVace):
model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache()
empty_cache()
event = initialized_events[gpu]
in_q = in_q_list[gpu]
event.set()
@ -748,7 +758,7 @@ class WanVaceMP(WanVace):
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache()
empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(
@ -758,7 +768,7 @@ class WanVaceMP(WanVace):
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()