[feature] adapt for Moore Threads GPU family

Houchen Li 2025-08-06 20:05:57 +08:00
parent 7c81b2f27d
commit 6d7fc288d8
19 changed files with 1026 additions and 192 deletions

.gitignore vendored

@ -21,7 +21,7 @@
*.html
*.pdf
*.whl
cache
*cache/
__pycache__/
storage/
samples/
@ -29,9 +29,11 @@ samples/
!requirements.txt
.DS_Store
*DS_Store
.vscode
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock
logs/


@ -12,12 +12,25 @@ import random
import torch
import torch.distributed as dist
from torch.cuda import set_device
from PIL import Image
try:
import torch_musa
from torch_musa.core.device import set_device
except ModuleNotFoundError:
torch_musa = None
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
from wan.utils.platform import (
get_device_type,
get_torch_distributed_backend,
get_torch_profiler_activities,
)
EXAMPLE_PROMPT = {
@ -243,6 +256,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="Profile the generation procedure.")
args = parser.parse_args()
@ -263,6 +281,30 @@ def _init_logging(rank):
logging.basicConfig(level=logging.ERROR)
def _init_profiler():
profiler = torch.profiler.profile(
activities=get_torch_profiler_activities(),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.start()
return profiler
def _finalize_profiler(profiler):
profiler.stop()
table = profiler.key_averages().table(
sort_by=f"{get_device_type()}_time_total",
row_limit=20,
)
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
with open(f"logs/profiling-{file_name}.txt", "w") as f:
f.write(table)
del file_name
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
@ -275,9 +317,9 @@ def generate(args):
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
set_device(local_rank)
dist.init_process_group(
backend="nccl",
backend=get_torch_distributed_backend(),
init_method="env://",
rank=rank,
world_size=world_size)
@ -330,6 +372,10 @@ def generate(args):
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
profiler = None
if args.profile and rank == 0:
profiler = _init_profiler()
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -366,10 +412,23 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
logging.info("Warming up WanT2V pipeline ...")
with torch.no_grad():
_ = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -423,8 +482,23 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)
logging.info("Warming up WanI2V pipeline ...")
with torch.no_grad():
_ = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
@ -481,8 +555,24 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)
logging.info("Warming up WanFLF2V pipeline ...")
with torch.no_grad():
_ = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info("Generating video ...")
video = wan_flf2v.generate(
args.prompt,
@ -529,6 +619,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
@ -537,6 +628,22 @@ def generate(args):
args.src_ref_images.split(',')
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info("Warming up VACE pipeline ...")
with torch.no_grad():
_ = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generating video...")
video = wan_vace.generate(
args.prompt,
@ -554,6 +661,9 @@ def generate(args):
else:
raise ValueError(f"Unkown task type: {args.task}")
if args.profile and rank == 0:
_finalize_profiler(profiler)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
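
For reference, a minimal self-contained sketch (not part of the commit) of the profiling flow wired in above: the profiler built by _init_profiler() is started once on rank 0, stepped every sampling iteration via profiler.step(), and summarized by _finalize_profiler() into a table sorted by total device time. The matmul below is a stand-in for a real denoising step.

import os
from datetime import datetime

import torch

from wan.utils.platform import get_device_type, get_torch_profiler_activities


def profiled_toy_run(num_steps: int = 5) -> None:
    # Same construction as _init_profiler(): CPU + accelerator activities,
    # traces streamed to ./logs for TensorBoard.
    profiler = torch.profiler.profile(
        activities=get_torch_profiler_activities(),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    profiler.start()
    for _ in range(num_steps):
        torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # stand-in for one sampling step
        profiler.step()  # same call the pipelines issue inside their timestep loops
    profiler.stop()

    # Same reporting as _finalize_profiler(): table sorted by total device time.
    table = profiler.key_averages().table(
        sort_by=f"{get_device_type()}_time_total", row_limit=20)
    os.makedirs("logs", exist_ok=True)
    with open(f"logs/profiling-{datetime.now():%Y-%m-%d_%H-%M-%S}.txt", "w") as f:
        f.write(table)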


@ -3,11 +3,17 @@ import gc
from functools import partial
import torch
from torch.cuda import empty_cache
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
try:
import torch_musa
from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
torch_musa = None
def shard_model(
model,
@ -40,4 +46,4 @@ def free_model(model):
_free_storage(m._handle.flat_param.data)
del model
gc.collect()
torch.cuda.empty_cache()
empty_cache()


@ -6,7 +6,18 @@ from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
attn_type:AttnType = AttnType.FA
from wan.modules.rope import rope_apply_pytorch, rope_apply_triton
try:
import torch_musa
import torch_musa.core.amp as amp
attn_type = AttnType.TORCH
torch.backends.mudnn.allow_tf32 = True
except ImportError:
torch_musa = None
from ..modules.model import sinusoidal_embedding_1d
@ -25,7 +36,7 @@ def pad_freqs(original_tensor, target_len):
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
@ -51,8 +62,6 @@ def rope_apply(x, grid_sizes, freqs):
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
@ -109,9 +118,13 @@ def usp_dit_forward(
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)
if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -129,11 +142,9 @@ def usp_dit_forward(
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context_lens = None
@ -177,7 +188,7 @@ def usp_dit_forward(
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return x
def usp_attn_forward(self,
@ -200,8 +211,12 @@ def usp_attn_forward(self,
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
if torch_musa is None:
q = rope_apply(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
k = rope_apply(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
else:
q = rope_apply_pytorch(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
k = rope_apply_pytorch(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
# TODO: We should use unpaded q,k,v for attention.
# k_lens = seq_lens // get_sequence_parallel_world_size()
@ -210,7 +225,7 @@ def usp_attn_forward(self,
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
x = xFuserLongContextAttention()(
x = xFuserLongContextAttention(attn_type=attn_type)(
None,
query=half(q),
key=half(k),
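
A toy illustration (mine, not from the diff) of the sequence-parallel slicing that rope_apply now receives sp_size and sp_rank for: the per-sample RoPE table is padded up to s * sp_size rows and each rank reads only its own contiguous block of s rows.

import torch

seq_len, s, sp_size, sp_rank = 10, 4, 3, 2           # 10 valid positions, 4 tokens per rank
freqs_i = torch.arange(seq_len).view(seq_len, 1, 1).expand(seq_len, 1, 8)

# pad up to s * sp_size rows, analogous to pad_freqs(freqs_i, s * sp_size)
pad = torch.zeros(s * sp_size - seq_len, 1, 8, dtype=freqs_i.dtype)
freqs_padded = torch.cat([freqs_i, pad], dim=0)

# the slice freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank)] in rope_apply
freqs_rank = freqs_padded[sp_rank * s:(sp_rank + 1) * s]
print(freqs_rank[:, 0, 0])                            # tensor([8, 9, 0, 0])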


@ -6,27 +6,40 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
from wan.distributed.fsdp import shard_model
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.platform import get_device
from wan.utils.memory_format import convert_conv3d_weight_memory_format
class WanFLF2V:
@ -42,6 +55,7 @@ class WanFLF2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
profiler=None,
):
r"""
Initializes the image-to-video generation model components.
@ -66,7 +80,7 @@ class WanFLF2V:
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.use_usp = use_usp
@ -90,6 +104,7 @@ class WanFLF2V:
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
self.clip = CLIPModel(
dtype=config.clip_dtype,
@ -121,7 +136,8 @@ class WanFLF2V:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
if dit_fsdp:
self.model = shard_fn(self.model)
else:
@ -129,6 +145,7 @@ class WanFLF2V:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -183,6 +200,11 @@ class WanFLF2V:
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
@ -275,6 +297,10 @@ class WanFLF2V:
])[0]
y = torch.concat([msk, y])
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -323,10 +349,16 @@ class WanFLF2V:
}
if offload_model:
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = [latent.to(self.device)]
timestep = [t]
@ -336,12 +368,12 @@ class WanFLF2V:
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
@ -356,22 +388,30 @@ class WanFLF2V:
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
return videos[0] if self.rank == 0 else None


@ -6,27 +6,40 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
from wan.distributed.fsdp import shard_model
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.platform import get_device
from wan.utils.memory_format import convert_conv3d_weight_memory_format
class WanI2V:
@ -42,6 +55,7 @@ class WanI2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
profiler=None,
):
r"""
Initializes the image-to-video generation model components.
@ -66,7 +80,7 @@ class WanI2V:
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.use_usp = use_usp
@ -90,6 +104,7 @@ class WanI2V:
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
self.clip = CLIPModel(
dtype=config.clip_dtype,
@ -121,7 +136,8 @@ class WanI2V:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
if dit_fsdp:
self.model = shard_fn(self.model)
else:
@ -129,6 +145,7 @@ class WanI2V:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -178,6 +195,11 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
@ -248,6 +270,10 @@ class WanI2V:
])[0]
y = torch.concat([msk, y])
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -296,10 +322,16 @@ class WanI2V:
}
if offload_model:
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = [latent.to(self.device)]
timestep = [t]
@ -309,12 +341,12 @@ class WanI2V:
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
@ -329,22 +361,30 @@ class WanI2V:
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
return videos[0] if self.rank == 0 else None


@ -1,4 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings
import torch
try:
@ -13,7 +15,14 @@ try:
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
try:
import torch_musa
FLASH_ATTN_3_AVAILABLE = False
FLASH_ATTN_2_AVAILABLE = False
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
__all__ = [
'flash_attention',
@ -51,7 +60,7 @@ def flash_attention(
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
assert q.device.type in ("cuda", "musa") and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@ -172,8 +181,9 @@ def attention(
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
out = out.transpose(1, 2).contiguous()
return out
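
For clarity, a standalone sketch (mine, not part of the diff) of the fallback path above: plain scaled_dot_product_attention on the transposed [B, N, L, C] layout, with the flash backend hinted through torch.backends.cuda.sdp_kernel as the patched attention() does. Shapes and dtype are illustrative; bfloat16 is one of the two half dtypes the function accepts.

import torch
import torch.nn.functional as F

# toy shapes after attention()'s transpose(1, 2): [batch, heads, seq, head_dim]
q = torch.randn(1, 8, 64, 128, dtype=torch.bfloat16)
k = torch.randn(1, 8, 64, 128, dtype=torch.bfloat16)
v = torch.randn(1, 8, 64, 128, dtype=torch.bfloat16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)

out = out.transpose(1, 2).contiguous()  # back to [batch, seq, heads, head_dim]
print(out.shape)                        # torch.Size([1, 64, 8, 128])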


@ -6,12 +6,20 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision.transforms as T
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
try:
import torch_musa
import torch_musa.core.amp as amp
from .attention import attention as flash_attention
except ModuleNotFoundError:
torch_musa = None
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len):
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
@ -44,12 +52,6 @@ class QuickGELU(nn.Module):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
@ -82,7 +84,7 @@ class SelfAttention(nn.Module):
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
x = x.reshape(b, s, c)
# output
@ -131,10 +133,10 @@ class AttentionBlock(nn.Module):
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
@ -177,7 +179,7 @@ class AttentionPool(nn.Module):
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.norm = nn.LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@ -259,13 +261,13 @@ class VisionTransformer(nn.Module):
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
@ -537,6 +539,6 @@ class CLIPModel:
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out


@ -7,7 +7,16 @@ import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
from wan.modules.attention import flash_attention
from wan.modules.rope import rope_apply_pytorch
try:
import torch_musa
import torch_musa.core.amp as amp
from wan.modules.attention import attention as flash_attention
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
__all__ = ['WanModel']
@ -19,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
position = position.type(torch.bfloat16)
# calculation
sinusoid = torch.outer(
@ -29,14 +38,33 @@ def sinusoidal_embedding_1d(dim, position):
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
def rope_params_real(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
freqs_real = torch.outer(
torch.arange(max_seq_len, dtype=dtype, device=device),
1.0
/ torch.pow(
theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
),
)
return torch.cos(freqs_real)
@amp.autocast(enabled=False)
def rope_params_imag(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0
freqs_imag = torch.outer(
torch.arange(max_seq_len, dtype=dtype, device=device),
1.0
/ torch.pow(
theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
),
)
return torch.sin(freqs_imag)
@amp.autocast(enabled=False)
@ -89,19 +117,6 @@ class WanRMSNorm(nn.Module):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self,
@ -145,10 +160,16 @@ class WanSelfAttention(nn.Module):
return q, k, v
q, k, v = qkv_fn(x)
if torch_musa is None:
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
else:
q = rope_apply_pytorch(q, grid_sizes, freqs)
k = rope_apply_pytorch(k, grid_sizes, freqs)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
q=q,
k=k,
v=v,
k_lens=seq_lens,
window_size=self.window_size)
@ -256,10 +277,10 @@ class WanAttentionBlock(nn.Module):
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
self.norm3 = nn.LayerNorm(
dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@ -267,7 +288,7 @@ class WanAttentionBlock(nn.Module):
(-1, -1),
qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
@ -293,24 +314,19 @@ class WanAttentionBlock(nn.Module):
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
@ -328,7 +344,7 @@ class Head(nn.Module):
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim)
# modulation
@ -340,10 +356,8 @@ class Head(nn.Module):
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = self.head(self.norm(x) * (1 + e[1]) + e[0])
return x
@ -477,12 +491,33 @@ class WanModel(ModelMixin, ConfigMixin):
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
if torch_musa is None:
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
else:
freqs_real = torch.cat(
[
rope_params_real(1024, d - 4 * (d // 6)),
rope_params_real(1024, 2 * (d // 6)),
rope_params_real(1024, 2 * (d // 6)),
],
dim=1,
)
freqs_imag = torch.cat(
[
rope_params_imag(1024, d - 4 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
],
dim=1,
)
self.freqs = (freqs_real, freqs_imag)
if model_type == 'i2v' or model_type == 'flf2v':
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@ -523,9 +558,17 @@ class WanModel(ModelMixin, ConfigMixin):
if self.model_type == 'i2v' or self.model_type == 'flf2v':
assert clip_fea is not None and y is not None
# params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if torch_musa is None:
if self.freqs.dtype != dtype or self.freqs.device != device:
self.freqs = self.freqs.to(dtype=dtype, device=device)
else:
if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device),
)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -543,11 +586,9 @@ class WanModel(ModelMixin, ConfigMixin):
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context_lens = None
@ -579,7 +620,7 @@ class WanModel(ModelMixin, ConfigMixin):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return x
def unpatchify(self, x, grid_sizes):
r"""
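
The switch above from a single complex table built with torch.polar to separate real (cos) and imaginary (sin) tables is a mathematical no-op: rotating (x_r, x_i) by angle theta gives (x_r*cos - x_i*sin, x_r*sin + x_i*cos), which is exactly complex multiplication by e^{i*theta}. A quick CPU check (illustrative only, not part of the commit):

import torch

theta, dim, seq = 10000.0, 8, 16
inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
ang = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)

# old rope_params: one complex table e^{i*ang}
freqs_complex = torch.polar(torch.ones_like(ang), ang)
x = torch.randn(seq, dim // 2, 2)
out_complex = torch.view_as_real(torch.view_as_complex(x) * freqs_complex)

# new rope_params_real / rope_params_imag: separate cos and sin tables
cos, sin = torch.cos(ang), torch.sin(ang)
out_real = x[..., 0] * cos - x[..., 1] * sin
out_imag = x[..., 0] * sin + x[..., 1] * cos

print(torch.allclose(out_complex, torch.stack([out_real, out_imag], dim=-1), atol=1e-6))  # True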

wan/modules/rope.py (new file)

@ -0,0 +1,317 @@
from typing import Optional, Tuple
import triton
import triton.language as tl
import torch
def pad_tensor(
original_tensor: torch.tensor, target_len: int, pad_value: float = 0.0
) -> torch.tensor:
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.full(
(pad_size, s1, s2),
pad_value,
dtype=original_tensor.dtype,
device=original_tensor.device,
)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
def rope_apply_pytorch(
x: torch.tensor,
grid_sizes: torch.tensor,
freqs: Tuple[torch.tensor],
sp_size: Optional[int] = None,
sp_rank: Optional[int] = None,
) -> torch.tensor:
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
c0 = c - 2 * (c // 3)
c1 = c // 3
c2 = c // 3
# split freqs
freqs_real = freqs[0].split([c0, c1, c2], dim=1)
freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = x[i, :seq_len].reshape(s, n, -1, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_real = torch.cat(
[
freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_imag = torch.cat(
[
freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1,
).reshape(seq_len, 1, -1)
if sp_rank is None:
freqs_real_rank = freqs_real
freqs_imag_rank = freqs_imag
else:
freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
# append to collection
output.append(x_out)
return torch.stack(output)
@triton.jit
def rope_kernel(
x_ptr, # [B, S, N, 2C]
grid_sizes_ptr, # [B, 3]
freqs_real_ptr, # [M, C]
freqs_imag_ptr, # [M, C]
output_ptr, # [B, S, N, 2C]
sp_size, # SP world size
sp_rank, # SP rank
B,
S,
N: tl.constexpr,
C: tl.constexpr,
M: tl.constexpr,
CfM: tl.constexpr,
ChM: tl.constexpr,
CwM: tl.constexpr,
SEQ_BLOCK: tl.constexpr,
HEADS_BLOCK: tl.constexpr,
):
Cf = C - 2 * (C // 3)
Ch = C // 3
Cw = C // 3
batch_idx = tl.program_id(0)
seqlen_group_idx = tl.program_id(1)
head_group_idx = tl.program_id(2)
base = batch_idx * 3
F = tl.load(grid_sizes_ptr + base + 0)
H = tl.load(grid_sizes_ptr + base + 1)
W = tl.load(grid_sizes_ptr + base + 2)
seq_len = F * H * W
global_offset = sp_rank * S + seqlen_group_idx * SEQ_BLOCK
seq_indices = global_offset + tl.arange(0, SEQ_BLOCK)
limit = tl.minimum(seq_len, S * sp_size)
seq_mask = seq_indices < limit
seq_indices = tl.where(seq_mask, seq_indices, 0)
HW = H * W
f_idx = seq_indices // HW
rem = seq_indices - f_idx * HW
h_idx = rem // W
w_idx = rem - h_idx * W
freq_offset_cf = tl.arange(0, CfM) # column offsets for segment 1: [0, Cf)
freq_offset_ch = Cf + tl.arange(0, ChM) # column offsets for segment 2: [Cf, Cf+Ch)
freq_offset_cw = Cf + Ch + tl.arange(0, CwM) # column offsets for segment 3: [Cf+Ch, C)
# gather the frequency value for each sequence position (broadcast across rows per position)
# address into the frequency table = idx * C + col_offset
freq_addr_cf = f_idx[:, None] * C + freq_offset_cf[None, :]
freq_addr_ch = h_idx[:, None] * C + freq_offset_ch[None, :]
freq_addr_cw = w_idx[:, None] * C + freq_offset_cw[None, :]
freqs_real_cf = tl.load(
freqs_real_ptr + freq_addr_cf,
mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
other=1.0,
).to(tl.float32)
freqs_imag_cf = tl.load(
freqs_imag_ptr + freq_addr_cf,
mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
other=1.0,
).to(tl.float32)
freqs_real_ch = tl.load(
freqs_real_ptr + freq_addr_ch,
mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
other=1.0,
).to(tl.float32)
freqs_imag_ch = tl.load(
freqs_imag_ptr + freq_addr_ch,
mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
other=1.0,
).to(tl.float32)
freqs_real_cw = tl.load(
freqs_real_ptr + freq_addr_cw,
mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
other=1.0,
).to(tl.float32)
freqs_imag_cw = tl.load(
freqs_imag_ptr + freq_addr_cw,
mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
other=1.0,
).to(tl.float32)
# add a head axis to the frequency values so they broadcast against x
freqs_real_cf = freqs_real_cf[:, None, :] # [SEQ_BLOCK, 1, Cf]
freqs_imag_cf = freqs_imag_cf[:, None, :]
freqs_real_ch = freqs_real_ch[:, None, :]
freqs_imag_ch = freqs_imag_ch[:, None, :]
freqs_real_cw = freqs_real_cw[:, None, :]
freqs_imag_cw = freqs_imag_cw[:, None, :]
# load the real and imaginary parts of the matching x block (shape: [SEQ_BLOCK, HEADS_BLOCK, C])
seq_offset = seqlen_group_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
head_offset = head_group_idx * HEADS_BLOCK + tl.arange(0, HEADS_BLOCK)
# compute the offset into x_ptr
base_offset = batch_idx * S * N * 2 * C
seq_head_offset = (
base_offset
+ seq_offset[:, None, None] * (N * 2 * C)
+ head_offset[None, :, None] * (2 * C)
)
x_mask = (seq_offset < S)[:, None, None] & (head_offset < N)[None, :, None]
# load each channel segment of input x; positions beyond the valid length are masked to 0
# segment 1: channels [0, Cf-1]
chan_cf = tl.arange(0, CfM * 2)
mask_2cf_chan = chan_cf < Cf * 2
x_cf = tl.load(
x_ptr + seq_head_offset + chan_cf[None, None, :],
mask=(x_mask & mask_2cf_chan[None, None, :]),
other=0.0,
).to(tl.float32)
x_cf = x_cf.reshape(
SEQ_BLOCK, HEADS_BLOCK, CfM, 2
) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
x_real_cf, x_imag_cf = x_cf.split()
# apply the RoPE rotation to segment 1
out_real_cf = x_real_cf * freqs_real_cf - x_imag_cf * freqs_imag_cf
out_imag_cf = x_real_cf * freqs_imag_cf + x_imag_cf * freqs_real_cf
out_cf = tl.interleave(out_real_cf, out_imag_cf) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
tl.store(
output_ptr + seq_head_offset + chan_cf[None, None, :],
out_cf,
mask=(x_mask & mask_2cf_chan[None, None, :]),
)
# segment 2: channels [Cf, Cf+Ch-1]
chan_ch = tl.arange(0, ChM * 2) + Cf * 2
mask_2ch_chan = chan_ch < 2 * (Cf + Ch)
x_ch = tl.load(
x_ptr + seq_head_offset + chan_ch[None, None, :],
mask=(x_mask & mask_2ch_chan[None, None, :]),
other=0.0,
).to(tl.float32)
x_ch = x_ch.reshape(SEQ_BLOCK, HEADS_BLOCK, ChM, 2)
x_real_ch, x_imag_ch = x_ch.split()
out_real_ch = x_real_ch * freqs_real_ch - x_imag_ch * freqs_imag_ch
out_imag_ch = x_real_ch * freqs_imag_ch + x_imag_ch * freqs_real_ch
out_ch = tl.interleave(out_real_ch, out_imag_ch) # [SEQ_BLOCK, HEADS_BLOCK, ChM, 2]
tl.store(
output_ptr + seq_head_offset + chan_ch[None, None, :],
out_ch,
mask=(x_mask & mask_2ch_chan[None, None, :]),
)
# segment 3: channels [Cf+Ch, C-1]
chan_cw = tl.arange(0, CwM * 2) + (Cf + Ch) * 2
mask_2cw_chan = chan_cw < 2 * C
x_cw = tl.load(
x_ptr + seq_head_offset + chan_cw[None, None, :],
mask=(x_mask & mask_2cw_chan[None, None, :]),
other=0.0,
).to(tl.float32)
x_cw = x_cw.reshape(SEQ_BLOCK, HEADS_BLOCK, CwM, 2)
x_real_cw, x_imag_cw = x_cw.split()
out_real_cw = x_real_cw * freqs_real_cw - x_imag_cw * freqs_imag_cw
out_imag_cw = x_real_cw * freqs_imag_cw + x_imag_cw * freqs_real_cw
out_cw = tl.interleave(out_real_cw, out_imag_cw)
tl.store(
output_ptr + seq_head_offset + chan_cw[None, None, :],
out_cw,
mask=(x_mask & mask_2cw_chan[None, None, :]),
)
@torch._dynamo.disable
def rope_apply_triton(
x: torch.tensor,
grid_sizes: torch.tensor,
freqs: Tuple[torch.tensor],
sp_size: Optional[int] = None,
sp_rank: Optional[int] = None,
) -> torch.tensor:
"""
x: [1, 9450, 40, 128]
grid_sizes: [[21, 45, 80]]
freqs_real: [1024, 64]
freqs_imag: [1024, 64]
"""
B, S, N, C = x.shape
C = C // 2
Cf = C - 2 * (C // 3) # frequency length for the first (frame) dimension
Ch = C // 3 # frequency length for the second (height) dimension
Cw = C // 3 # frequency length for the third (width) dimension
M = freqs[0].shape[0]
SEQ_BLOCK = 64 # sequence positions handled per program instance
HEADS_BLOCK = 8 # heads handled per program instance
if sp_rank is None:
sp_size = 1
sp_rank = 0
grid_sizes = grid_sizes.to(device=x.device)
output = torch.empty_like(x)
rope_kernel[(B, triton.cdiv(S, SEQ_BLOCK), triton.cdiv(N, HEADS_BLOCK))](
x,
grid_sizes,
freqs[0],
freqs[-1],
output,
sp_size,
sp_rank,
B,
S,
N=N,
C=C,
M=M,
CfM=triton.next_power_of_2(Cf),
ChM=triton.next_power_of_2(Ch),
CwM=triton.next_power_of_2(Cw),
SEQ_BLOCK=SEQ_BLOCK,
HEADS_BLOCK=HEADS_BLOCK,
num_warps=32,
num_stages=3,
)
return output.float()
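
A shape-level usage sketch for the pure-PyTorch path (assuming a single sample and no sequence parallelism; note the module imports triton at the top, so triton must be installed even when only rope_apply_pytorch is used):

import torch

from wan.modules.rope import rope_apply_pytorch

b, f, h, w, heads, head_dim = 1, 2, 4, 4, 2, 12       # head_dim = 2 * c with c = 6
seq_len, c = f * h * w, head_dim // 2
x = torch.randn(b, seq_len, heads, head_dim)
grid_sizes = torch.tensor([[f, h, w]])

# (cos, sin) tables: one row per position index, c columns split into f/h/w chunks inside
ang = torch.outer(
    torch.arange(32, dtype=torch.float32),
    1.0 / torch.pow(10000.0, torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim),
)
freqs = (torch.cos(ang), torch.sin(ang))

out = rope_apply_pytorch(x, grid_sizes, freqs)        # sp_size / sp_rank omitted: single rank
print(out.shape)                                      # torch.Size([1, 32, 2, 12])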


@ -7,7 +7,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
try:
import torch_musa
except ModuleNotFoundError:
torch_musa = None
from wan.modules.tokenizers import HuggingfaceTokenizer
from wan.utils.platform import get_device
__all__ = [
'T5Model',
@ -59,10 +65,8 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
@ -110,7 +114,7 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
attn = F.softmax(attn, dim=-1)
x = torch.einsum('bnij,bjnc->binc', attn, v)
# output
@ -255,7 +259,7 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
@ -475,7 +479,7 @@ class T5EncoderModel:
self,
text_len,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
device=get_device(),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,


@ -4,6 +4,12 @@ import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
try:
import torch_musa
import torch_musa.core.amp as amp
except ModuleNotFoundError:
torch_musa = None
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d


@ -1,12 +1,21 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
from math import sqrt
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample
from einops import rearrange
try:
import torch_musa
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
from wan.utils.platform import get_device_type
__all__ = [
'WanVAE',
]
@ -44,23 +53,17 @@ class RMS_norm(nn.Module):
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.scale = sqrt(dim)
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
return (
F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
* self.scale
* self.gamma
+ self.bias
)
class Resample(nn.Module):
@ -253,6 +256,10 @@ class AttentionBlock(nn.Module):
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
@ -621,8 +628,8 @@ class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=torch.float,
device="cuda"):
dtype=torch.bfloat16,
device=get_device_type()):
self.dtype = dtype
self.device = device
@ -648,16 +655,12 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
return [
self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
]
def decode(self, zs):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
return [
self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
for u in zs
]


@ -6,24 +6,37 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
from wan.distributed.fsdp import shard_model
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.platform import get_device
from wan.utils.memory_format import convert_conv3d_weight_memory_format
class WanT2V:
@ -38,6 +51,7 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
profiler=None,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -60,7 +74,7 @@ class WanT2V:
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
@ -82,6 +96,7 @@ class WanT2V:
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
@ -103,13 +118,15 @@ class WanT2V:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -155,6 +172,11 @@ class WanT2V:
- H: Frame height (from size)
- W: Frame width from size)
"""
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@ -194,6 +216,10 @@ class WanT2V:
generator=seed_g)
]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -230,13 +256,19 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
@ -253,19 +285,27 @@ class WanT2V:
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
return videos[0] if self.rank == 0 else None


@ -5,9 +5,13 @@ from .fm_solvers import (
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
from .platform import get_device, get_device_type, get_torch_distributed_backend
from .memory_format import convert_conv3d_weight_memory_format
from .chrono_inspector import ChronoInspector
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor'
'VaceVideoProcessor', 'get_device', 'get_device_type', 'get_torch_distributed_backend',
'convert_conv3d_weight_memory_format', 'ChronoInspector'
]


@ -0,0 +1,15 @@
from time import perf_counter
from logging import info
class ChronoInspector(object):
def __init__(self, name:str="Block"):
self.name = name
def __enter__(self):
self.start_time:float = perf_counter()
return self # optional: return self so callers can read extra info
def __exit__(self, exc_type, exc_val, exc_tb):
end_time:float = perf_counter()
info(f"[{self.name}] Elapsed time: {end_time - self.start_time:.2f} seconds")
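
Usage is a plain context manager; a short sketch (assuming the module path wan.utils.chrono_inspector, and that logging is configured at INFO level, otherwise the message is dropped):

import logging
from time import sleep

from wan.utils.chrono_inspector import ChronoInspector

logging.basicConfig(level=logging.INFO)

with ChronoInspector("toy block"):
    sleep(0.25)
# INFO:root:[toy block] Elapsed time: 0.25 seconds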


@ -0,0 +1,76 @@
import torch
def convert_conv3d_weight_memory_format(module:torch.nn.Module, memory_format:torch.memory_format):
r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``
The conversion recursively applies to nested ``nn.Module``, including ``module``.
Note that it only changes the memory_format, but not the semantics of each dimension.
This function is used to facilitate the computation to adopt NHWC kernels, which
provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
.. note::
Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
than the utility function ``convert_conv3d_weight_memory_format``. Any
layer with 4d weight will be affected by ``model.to``, which does not
necessarily benefit from conversion to specified ``memory_format``.
One case we are confident about is the NDHWC (channels_last_3d) conversion for
convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
even in cases where we have to apply a permutation to input tensors.
Hence our strategy here is to convert only the weight of convolution to
channels_last_3d. This ensures that:
1. Fast convolution kernels will be used, the benefit of which could
outweigh overhead of permutation (if input is not in the same format).
2. No unnecessary permutations are applied on layers that do not benefit
from memory_format conversion.
The optimal case is that, layers between convolution layers are channels
last compatible. Input tensor would be permuted to channels last when it
encounters the first convolution layer and stay in that memory format.
Hence following convolutions will not need to permute its input tensor.
In case where a channels last incompatible layer is between convolution
layers, we need to permute the input tensor back to contiguous format
for that layer. The input tensor will go through the remaining layers in
contiguous format and be permuted to channels last when it encounters
another convolution layer. There's no point in propagating that
permutation to an earlier layer, as most layers are quite agnostic to
``memory_format``.
This claim might change when PyTorch supports fusion of permutation, as
there might have been a better spot to fuse the permutation other than
immediately before a convolution.
Args:
module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
``nn.Module``
memory_format: user specified ``memory_format``,
e.g. ``torch.channels_last`` or ``torch.contiguous_format``
Returns:
The original module with updated ``nn.Conv3d``
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
>>> model = nn.Sequential(
>>> nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
# beyond only 4d tensors.
if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
weight_data = (
module.weight.detach().clone().contiguous(memory_format=memory_format)
)
module.weight.data = weight_data.resize_(
weight_data.size(), memory_format=memory_format
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
return module
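
A CPU-only sanity check (illustrative, not part of the commit) that the helper changes only the weight layout of Conv3d modules and leaves the numerics untouched:

import torch
import torch.nn as nn

from wan.utils.memory_format import convert_conv3d_weight_memory_format

model = nn.Sequential(nn.Conv3d(8, 4, 3), nn.SiLU(), nn.Conv3d(4, 4, 3))
x = torch.randn(1, 8, 8, 8, 8)
ref = model(x)

convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
print(model[0].weight.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(torch.allclose(ref, model(x), atol=1e-5))                             # same numbers, new layout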

wan/utils/platform.py (new file)

@ -0,0 +1,61 @@
from typing import Optional, List
import torch
try:
import torch_musa
except ModuleNotFoundError:
torch_musa = None
def _is_musa() -> bool:
if torch_musa is None:
return False
else:
return True
def get_device(local_rank: Optional[int] = None) -> torch.device:
if torch.cuda.is_available():
return (
torch.cuda.current_device()
if local_rank is None
else torch.device("cuda", local_rank)
)
elif _is_musa():
return (
torch.musa.current_device()
if local_rank is None
else torch.device("musa", local_rank)
)
else:
return torch.device("cpu")
def get_device_type() -> str:
if torch.cuda.is_available():
return "cuda"
elif _is_musa():
return "musa"
else:
return "cpu"
def get_torch_distributed_backend() -> str:
if torch.cuda.is_available():
return "nccl"
elif _is_musa():
return "mccl"
else:
raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
activities: List[torch.profiler.ProfilerActivity] = [
torch.profiler.ProfilerActivity.CPU
]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
elif _is_musa():
activities.append(torch.profiler.ProfilerActivity.MUSA)
return activities
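
Condensing how the callers in this commit use these helpers into one hedged sketch (the process-group branch only makes sense under torchrun with RANK / WORLD_SIZE / LOCAL_RANK set):

import os

import torch.distributed as dist

from wan.utils.platform import (
    get_device,
    get_device_type,
    get_torch_distributed_backend,
    get_torch_profiler_activities,
)

local_rank = int(os.getenv("LOCAL_RANK", 0))
device = get_device(local_rank)                # cuda:N, musa:N, or cpu
print(get_device_type(), device, get_torch_profiler_activities())

if int(os.getenv("WORLD_SIZE", 1)) > 1:
    # "nccl" on NVIDIA, "mccl" on Moore Threads; raises NotImplementedError otherwise
    dist.init_process_group(
        backend=get_torch_distributed_backend(),
        init_method="env://",
        rank=int(os.getenv("RANK", 0)),
        world_size=int(os.getenv("WORLD_SIZE", 1)),
    )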


@ -8,11 +8,13 @@ import sys
import time
import traceback
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
@ -20,8 +22,17 @@ import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
from .modules.vace_model import VaceWanModel
from .text2video import (
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
torch_musa = None
from wan.modules.vace_model import VaceWanModel
from wan.text2video import (
FlowDPMSolverMultistepScheduler,
FlowUniPCMultistepScheduler,
T5EncoderModel,
@ -31,7 +42,9 @@ from .text2video import (
retrieve_timesteps,
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor
from wan.utils.vace_processor import VaceVideoProcessor
from wan.utils.platform import get_device, get_torch_distributed_backend
from wan.utils.memory_format import convert_conv3d_weight_memory_format
class WanVace(WanT2V):
@ -46,6 +59,7 @@ class WanVace(WanT2V):
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
profiler=None,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -68,7 +82,7 @@ class WanVace(WanT2V):
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.device = get_device(device_id)
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
@ -90,6 +104,7 @@ class WanVace(WanT2V):
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
@ -117,7 +132,8 @@ class WanVace(WanT2V):
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
if dit_fsdp:
self.model = shard_fn(self.model)
else:
@ -136,6 +152,8 @@ class WanVace(WanT2V):
seq_len=75600,
keep_last=True)
self.profiler = profiler
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
@ -340,6 +358,11 @@ class WanVace(WanT2V):
- H: Frame height (from size)
- W: Frame width from size)
"""
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
# preprocess
# F = frame_num
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@ -390,6 +413,10 @@ class WanVace(WanT2V):
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -426,13 +453,19 @@ class WanVace(WanT2V):
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input,
t=timestep,
@ -457,20 +490,28 @@ class WanVace(WanT2V):
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.decode_latent(x0, input_ref_images)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
return videos[0] if self.rank == 0 else None
@ -568,7 +609,7 @@ class WanVaceMP(WanVace):
torch.cuda.set_device(gpu)
dist.init_process_group(
backend='nccl',
backend=get_torch_distributed_backend(),
init_method='env://',
rank=rank,
world_size=world_size)
@ -629,11 +670,11 @@ class WanVaceMP(WanVace):
else:
sp_size = 1
dist.barrier()
# dist.barrier()
model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache()
empty_cache()
event = initialized_events[gpu]
in_q = in_q_list[gpu]
event.set()
@ -748,7 +789,7 @@ class WanVaceMP(WanVace):
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache()
empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(
@ -758,9 +799,10 @@ class WanVaceMP(WanVace):
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
synchronize()
if dist.is_initialized():
dist.barrier()
# dist.barrier()
pass
if rank == 0:
out_q.put(videos[0].cpu())