From 28919e5663ade36cd91534c1b784b91509c7aec6 Mon Sep 17 00:00:00 2001 From: Houchen Li Date: Wed, 6 Aug 2025 20:05:57 +0800 Subject: [PATCH] [feature] adapt for Moore Threads GPU family --- .gitignore | 2 +- generate.py | 282 ++++++++++++-------- wan/distributed/fsdp.py | 8 +- wan/distributed/xdit_context_parallel.py | 45 ++-- wan/first_last_frame2video.py | 45 ++-- wan/image2video.py | 45 ++-- wan/modules/attention.py | 18 +- wan/modules/clip.py | 30 ++- wan/modules/model.py | 153 +++++++---- wan/modules/rope.py | 317 +++++++++++++++++++++++ wan/modules/t5.py | 18 +- wan/modules/vace_model.py | 6 + wan/modules/vae.py | 57 ++-- wan/text2video.py | 39 ++- wan/utils/__init__.py | 6 +- wan/utils/chrono_inspector.py | 15 ++ wan/utils/memory_format.py | 76 ++++++ wan/utils/platform.py | 50 ++++ wan/vace.py | 44 +++- 19 files changed, 971 insertions(+), 285 deletions(-) create mode 100644 wan/modules/rope.py create mode 100644 wan/utils/chrono_inspector.py create mode 100644 wan/utils/memory_format.py create mode 100644 wan/utils/platform.py diff --git a/.gitignore b/.gitignore index deeef7a..daa41f1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,4 @@ Wan2.1-T2V-14B/ Wan2.1-T2V-1.3B/ Wan2.1-I2V-14B-480P/ Wan2.1-I2V-14B-720P/ -poetry.lock \ No newline at end of file +poetry.lock diff --git a/generate.py b/generate.py index c841c19..9764ab0 100644 --- a/generate.py +++ b/generate.py @@ -12,12 +12,21 @@ import random import torch import torch.distributed as dist +from torch.cuda import set_device from PIL import Image +try: + import torch_musa + from torch_musa.core.device import set_device +except ModuleNotFoundError: + torch_musa = None + import wan from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.utils import cache_image, cache_video, str2bool +from wan.utils.platform import get_torch_distributed_backend +from wan.utils.chrono_inspector import ChronoInspector EXAMPLE_PROMPT = { @@ -275,9 +284,9 @@ def generate(args): logging.info( f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: - torch.cuda.set_device(local_rank) + set_device(local_rank) dist.init_process_group( - backend="nccl", + backend=get_torch_distributed_backend(), init_method="env://", rank=rank, world_size=world_size) @@ -357,29 +366,43 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanT2V pipeline.") - wan_t2v = wan.WanT2V( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) + with ChronoInspector("Creating WanT2V pipeline"): + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) - logging.info( - f"Generating {'image' if 't2i' in args.task else 'video'} ...") - video = wan_t2v.generate( - args.prompt, - size=SIZE_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) + logging.info("Warming up WanT2V pipeline ...") + with ChronoInspector("Warming up WanT2V pipeline"): + _ = wan_t2v.generate( + 
args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...") + with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"): + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) elif "i2v" in args.task: if args.prompt is None: @@ -414,29 +437,45 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanI2V pipeline.") - wan_i2v = wan.WanI2V( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) + with ChronoInspector("Creating WanI2V pipeline"): + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info("Warming up WanI2V pipeline ...") + with ChronoInspector("Warming up WanI2V pipeline"): + _ = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) logging.info("Generating video ...") - video = wan_i2v.generate( - args.prompt, - img, - max_area=MAX_AREA_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) + with ChronoInspector("Generating video"): + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) elif "flf2v" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] @@ -472,30 +511,47 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanFLF2V pipeline.") - wan_flf2v = wan.WanFLF2V( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) + with ChronoInspector("Creating WanFLF2V pipeline"): + wan_flf2v = wan.WanFLF2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info("Warming up WanFLF2V pipeline ...") + with ChronoInspector("Warming up WanFLF2V pipeline"): + _ = wan_flf2v.generate( + args.prompt, + first_frame, + last_frame, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + 
shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) logging.info("Generating video ...") - video = wan_flf2v.generate( - args.prompt, - first_frame, - last_frame, - max_area=MAX_AREA_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) + with ChronoInspector("Generating video"): + video = wan_flf2v.generate( + args.prompt, + first_frame, + last_frame, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) elif "vace" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] @@ -520,16 +576,17 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating VACE pipeline.") - wan_vace = wan.WanVace( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) + with ChronoInspector("Creating VACE pipeline"): + wan_vace = wan.WanVace( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) src_video, src_mask, src_ref_images = wan_vace.prepare_source( [args.src_video], [args.src_mask], [ @@ -537,20 +594,37 @@ def generate(args): args.src_ref_images.split(',') ], args.frame_num, SIZE_CONFIGS[args.size], device) + logging.info("Warming up VACE pipeline ...") + with ChronoInspector("Warming up VACE pipeline"): + video = wan_vace.generate( + args.prompt, + src_video, + src_mask, + src_ref_images, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + logging.info(f"Generating video...") - video = wan_vace.generate( - args.prompt, - src_video, - src_mask, - src_ref_images, - size=SIZE_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) + with ChronoInspector("Generating video"): + video = wan_vace.generate( + args.prompt, + src_video, + src_mask, + src_ref_images, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) else: raise ValueError(f"Unkown task type: {args.task}") @@ -564,21 +638,23 @@ def generate(args): if "t2i" in args.task: logging.info(f"Saving generated image to {args.save_file}") - cache_image( - tensor=video.squeeze(1)[None], - save_file=args.save_file, - nrow=1, - normalize=True, - value_range=(-1, 1)) + with ChronoInspector("Saving generated image"): + cache_image( + tensor=video.squeeze(1)[None], + save_file=args.save_file, + 
nrow=1, + normalize=True, + value_range=(-1, 1)) else: logging.info(f"Saving generated video to {args.save_file}") - cache_video( - tensor=video[None], - save_file=args.save_file, - fps=cfg.sample_fps, - nrow=1, - normalize=True, - value_range=(-1, 1)) + with ChronoInspector("Saving generated video"): + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) logging.info("Finished.") diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 6bb496d..ca12048 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -3,11 +3,17 @@ import gc from functools import partial import torch +from torch.cuda import empty_cache from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.utils import _free_storage +try: + import torch_musa + from torch_musa.core.memory import empty_cache +except ModuleNotFoundError: + torch_musa = None def shard_model( model, @@ -40,4 +46,4 @@ def free_model(model): _free_storage(m._handle.flat_param.data) del model gc.collect() - torch.cuda.empty_cache() + empty_cache() diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index 4718577..765c95e 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -6,7 +6,18 @@ from xfuser.core.distributed import ( get_sequence_parallel_world_size, get_sp_group, ) -from xfuser.core.long_ctx_attention import xFuserLongContextAttention +from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType +attn_type:AttnType = AttnType.FA + +from wan.modules.rope import rope_apply_pytorch, rope_apply_triton + +try: + import torch_musa + import torch_musa.core.amp as amp + attn_type = AttnType.TORCH + torch.backends.mudnn.allow_tf32 = True +except ImportError: + torch_musa = None from ..modules.model import sinusoidal_embedding_1d @@ -25,7 +36,7 @@ def pad_freqs(original_tensor, target_len): @amp.autocast(enabled=False) -def rope_apply(x, grid_sizes, freqs): +def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank): """ x: [B, L, N, C]. grid_sizes: [B, 3]. 
@@ -51,8 +62,6 @@ def rope_apply(x, grid_sizes, freqs): dim=-1).reshape(seq_len, 1, -1) # apply rotary embedding - sp_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() freqs_i = pad_freqs(freqs_i, s * sp_size) s_per_rank = s freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * @@ -109,9 +118,13 @@ def usp_dit_forward( if self.model_type == 'i2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) if self.model_type != 'vace' and y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -129,11 +142,9 @@ def usp_dit_forward( ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -177,7 +188,7 @@ def usp_dit_forward( # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def usp_attn_forward(self, @@ -200,8 +211,12 @@ def usp_attn_forward(self, return q, k, v q, k, v = qkv_fn(x) - q = rope_apply(q, grid_sizes, freqs) - k = rope_apply(k, grid_sizes, freqs) + if torch_musa is None: + q = rope_apply(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + k = rope_apply(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + else: + q = rope_apply_pytorch(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + k = rope_apply_pytorch(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) # TODO: We should use unpaded q,k,v for attention. 
# k_lens = seq_lens // get_sequence_parallel_world_size() @@ -210,7 +225,7 @@ def usp_attn_forward(self, # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) - x = xFuserLongContextAttention()( + x = xFuserLongContextAttention(attn_type=attn_type)( None, query=half(q), key=half(k), diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 232950f..8e9b64d 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -12,21 +12,33 @@ from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm -from .distributed.fsdp import shard_model -from .modules.clip import CLIPModel -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.clip import CLIPModel +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanFLF2V: @@ -66,7 +78,7 @@ class WanFLF2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -90,6 +102,7 @@ class WanFLF2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) self.clip = CLIPModel( dtype=config.clip_dtype, @@ -121,7 +134,8 @@ class WanFLF2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -323,7 +337,7 @@ class WanFLF2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): @@ -336,12 +350,12 @@ class WanFLF2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -361,7 +375,7 @@ class WanFLF2V: if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -370,8 +384,9 @@ class WanFLF2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..c96c0e3 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -12,21 +12,33 @@ from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm -from .distributed.fsdp import shard_model -from .modules.clip import CLIPModel -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.clip import CLIPModel +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanI2V: @@ -66,7 +78,7 @@ class WanI2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -90,6 +102,7 @@ class WanI2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) self.clip = CLIPModel( dtype=config.clip_dtype, @@ -121,7 +134,8 @@ class WanI2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -296,7 +310,7 @@ class WanI2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): @@ -309,12 +323,12 @@ class WanI2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -334,7 +348,7 @@ class WanI2V: if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -343,8 +357,9 @@ class WanI2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 4dbbe03..796c633 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -1,4 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import warnings + import torch try: @@ -13,7 +15,14 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False -import warnings +try: + import torch_musa + FLASH_ATTN_3_AVAILABLE = False + FLASH_ATTN_2_AVAILABLE = False + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'flash_attention', @@ -51,7 +60,7 @@ def flash_attention( """ half_dtypes = (torch.float16, torch.bfloat16) assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 + assert q.device.type in ("cuda", "musa") and q.size(-1) <= 256 # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype @@ -172,8 +181,9 @@ def attention( k = k.transpose(1, 2).to(dtype) v = v.transpose(1, 2).to(dtype) - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale) out = out.transpose(1, 2).contiguous() return out diff --git a/wan/modules/clip.py b/wan/modules/clip.py index 42dda04..496ed58 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -6,12 +6,20 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import torch.cuda.amp as amp import torchvision.transforms as T from .attention import flash_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta +try: + import torch_musa + import torch_musa.core.amp as amp + from .attention import attention as flash_attention +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'XLMRobertaCLIP', 'clip_xlm_roberta_vit_h_14', @@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len): return torch.cat([ pos[:, :n], F.interpolate( - pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + pos[:, n:].reshape(1, src_grid, src_grid, -1).permute( 0, 3, 1, 2), size=(tar_grid, tar_grid), mode='bicubic', @@ -44,12 +52,6 @@ class QuickGELU(nn.Module): return x * torch.sigmoid(1.702 * x) -class LayerNorm(nn.LayerNorm): - - def forward(self, x): - return super().forward(x.float()).type_as(x) - - class SelfAttention(nn.Module): def __init__(self, @@ -82,7 +84,7 @@ class SelfAttention(nn.Module): # compute attention p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal) x = x.reshape(b, s, c) # output @@ -131,10 +133,10 @@ class AttentionBlock(nn.Module): self.norm_eps = norm_eps # layers - self.norm1 = LayerNorm(dim, eps=norm_eps) + self.norm1 = nn.LayerNorm(dim, eps=norm_eps) self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) - self.norm2 = LayerNorm(dim, eps=norm_eps) + self.norm2 = nn.LayerNorm(dim, eps=norm_eps) if activation == 'swi_glu': self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) else: @@ -177,7 +179,7 @@ class AttentionPool(nn.Module): self.to_q = nn.Linear(dim, dim) self.to_kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) - self.norm = LayerNorm(dim, eps=norm_eps) + self.norm = nn.LayerNorm(dim, eps=norm_eps) self.mlp = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == 'quick_gelu' else nn.GELU(), @@ -259,13 +261,13 @@ class VisionTransformer(nn.Module): self.dropout = nn.Dropout(embedding_dropout) # transformer - self.pre_norm = 
LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
         self.transformer = nn.Sequential(*[
             AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                            activation, attn_dropout, proj_dropout, norm_eps)
             for _ in range(num_layers)
         ])
-        self.post_norm = LayerNorm(dim, eps=norm_eps)
+        self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
 
         # head
         if pool_type == 'token':
@@ -537,6 +539,6 @@ class CLIPModel:
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
 
         # forward
-        with torch.cuda.amp.autocast(dtype=self.dtype):
+        with amp.autocast(dtype=self.dtype):
             out = self.model.visual(videos, use_31_block=True)
         return out
diff --git a/wan/modules/model.py b/wan/modules/model.py
index a5425da..467b557 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -7,7 +7,16 @@ import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 
-from .attention import flash_attention
+from wan.modules.attention import flash_attention
+from wan.modules.rope import rope_apply_pytorch
+
+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from wan.modules.attention import attention as flash_attention
+    torch.backends.mudnn.allow_tf32 = True
+except ModuleNotFoundError:
+    torch_musa = None  # keep the name bound so `if torch_musa is None` checks work on CUDA
 
 __all__ = ['WanModel']
 
@@ -19,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
     half = dim // 2
-    position = position.type(torch.float64)
+    position = position.type(torch.bfloat16)
 
     # calculation
     sinusoid = torch.outer(
@@ -29,14 +38,33 @@
 
 
 @amp.autocast(enabled=False)
-def rope_params(max_seq_len, dim, theta=10000):
+def rope_params_real(
+    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
+):
     assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
+    freqs_real = torch.outer(
+        torch.arange(max_seq_len, dtype=dtype, device=device),
+        1.0
+        / torch.pow(
+            theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
+        ),
+    )
+    return torch.cos(freqs_real)
+
+
+@amp.autocast(enabled=False)
+def rope_params_imag(
+    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
+):
+    assert dim % 2 == 0
+    freqs_imag = torch.outer(
+        torch.arange(max_seq_len, dtype=dtype, device=device),
+        1.0
+        / torch.pow(
+            theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
+        ),
+    )
+    return torch.sin(freqs_imag)
 
 
 @amp.autocast(enabled=False)
@@ -89,19 +117,6 @@ class WanRMSNorm(nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
 
 
-class WanLayerNorm(nn.LayerNorm):
-
-    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
-        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
-
-    def forward(self, x):
-        r"""
-        Args:
-            x(Tensor): Shape [B, L, C]
-        """
-        return super().forward(x.float()).type_as(x)
-
-
 class WanSelfAttention(nn.Module):
 
     def __init__(self,
@@ -145,10 +160,16 @@
             return q, k, v
 
         q, k, v = qkv_fn(x)
+        if torch_musa is None:
+            q = rope_apply(q, grid_sizes, freqs)
+            k = rope_apply(k, grid_sizes, freqs)
+        else:
+            q = rope_apply_pytorch(q, grid_sizes, freqs)
+            k = rope_apply_pytorch(k, grid_sizes, freqs)
 
         x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, 
grid_sizes, freqs), + q=q, + k=k, v=v, k_lens=seq_lens, window_size=self.window_size) @@ -256,10 +277,10 @@ class WanAttentionBlock(nn.Module): self.eps = eps # layers - self.norm1 = WanLayerNorm(dim, eps) + self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) - self.norm3 = WanLayerNorm( + self.norm3 = nn.LayerNorm( dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, @@ -267,7 +288,7 @@ class WanAttentionBlock(nn.Module): (-1, -1), qk_norm, eps) - self.norm2 = WanLayerNorm(dim, eps) + self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim)) @@ -293,24 +314,19 @@ class WanAttentionBlock(nn.Module): grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 + e = (self.modulation + e).chunk(6, dim=1) # self-attention y = self.self_attn( - self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, + self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) - with amp.autocast(dtype=torch.float32): - x = x + y * e[2] + x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) - with amp.autocast(dtype=torch.float32): - x = x + y * e[5] + y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) + x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) @@ -328,7 +344,7 @@ class Head(nn.Module): # layers out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) + self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False) self.head = nn.Linear(dim, out_dim) # modulation @@ -340,10 +356,8 @@ class Head(nn.Module): x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) return x @@ -477,12 +491,33 @@ class WanModel(ModelMixin, ConfigMixin): # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) + if torch_musa is None: + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + else: + freqs_real = torch.cat( + [ + rope_params_real(1024, d - 4 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + ], + dim=1, + ) + freqs_imag = torch.cat( + [ + rope_params_imag(1024, d - 4 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + ], + dim=1, + ) + self.freqs = (freqs_real, freqs_imag) if model_type == 'i2v' or model_type == 'flf2v': self.img_emb = MLPProj(1280, dim, 
flf_pos_emb=model_type == 'flf2v') @@ -523,9 +558,17 @@ class WanModel(ModelMixin, ConfigMixin): if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if torch_musa is None: + if self.freqs.dtype != dtype or self.freqs.device != device: + self.freqs = self.freqs.to(dtype=dtype, device=device) + else: + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device), + ) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -543,11 +586,9 @@ class WanModel(ModelMixin, ConfigMixin): ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -579,7 +620,7 @@ class WanModel(ModelMixin, ConfigMixin): # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def unpatchify(self, x, grid_sizes): r""" diff --git a/wan/modules/rope.py b/wan/modules/rope.py new file mode 100644 index 0000000..ca204ab --- /dev/null +++ b/wan/modules/rope.py @@ -0,0 +1,317 @@ +from typing import Optional, Tuple + +import triton +import triton.language as tl +import torch + + +def pad_tensor( + original_tensor: torch.tensor, target_len: int, pad_value: float = 0.0 +) -> torch.tensor: + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.full( + (pad_size, s1, s2), + pad_value, + dtype=original_tensor.dtype, + device=original_tensor.device, + ) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +def rope_apply_pytorch( + x: torch.tensor, + grid_sizes: torch.tensor, + freqs: Tuple[torch.tensor], + sp_size: Optional[int] = None, + sp_rank: Optional[int] = None, +) -> torch.tensor: + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. 
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    c0 = c - 2 * (c // 3)
+    c1 = c // 3
+    c2 = c // 3
+
+    # split freqs
+    freqs_real = freqs[0].split([c0, c1, c2], dim=1)
+    freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = x[i, :seq_len].reshape(s, n, -1, 2)
+        x_real = x_i[..., 0]
+        x_imag = x_i[..., 1]
+        # build per-sample copies so the split tuples above stay intact
+        # across iterations of the batch loop
+        freqs_real_i = torch.cat(
+            [
+                freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+                freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+                freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+        freqs_imag_i = torch.cat(
+            [
+                freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+                freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+                freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+
+        if sp_rank is None:
+            freqs_real_rank = freqs_real_i
+            freqs_imag_rank = freqs_imag_i
+        else:
+            freqs_real_i = pad_tensor(freqs_real_i, s * sp_size, 1.0)
+            freqs_imag_i = pad_tensor(freqs_imag_i, s * sp_size, 0.0)
+            freqs_real_rank = freqs_real_i[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
+            freqs_imag_rank = freqs_imag_i[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
+
+        out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
+        out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
+
+        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
+        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
+
+        # append to collection
+        output.append(x_out)
+
+    return torch.stack(output)
+
+
+@triton.jit
+def rope_kernel(
+    x_ptr,  # [B, S, N, 2C]
+    grid_sizes_ptr,  # [B, 3]
+    freqs_real_ptr,  # [M, C]
+    freqs_imag_ptr,  # [M, C]
+    output_ptr,  # [B, S, N, 2C]
+    sp_size,  # SP world size
+    sp_rank,  # SP rank
+    B,
+    S,
+    N: tl.constexpr,
+    C: tl.constexpr,
+    M: tl.constexpr,
+    CfM: tl.constexpr,
+    ChM: tl.constexpr,
+    CwM: tl.constexpr,
+    SEQ_BLOCK: tl.constexpr,
+    HEADS_BLOCK: tl.constexpr,
+):
+    Cf = C - 2 * (C // 3)
+    Ch = C // 3
+    Cw = C // 3
+
+    batch_idx = tl.program_id(0)
+    seqlen_group_idx = tl.program_id(1)
+    head_group_idx = tl.program_id(2)
+
+    base = batch_idx * 3
+    F = tl.load(grid_sizes_ptr + base + 0)
+    H = tl.load(grid_sizes_ptr + base + 1)
+    W = tl.load(grid_sizes_ptr + base + 2)
+    seq_len = F * H * W
+
+    global_offset = sp_rank * S + seqlen_group_idx * SEQ_BLOCK
+    seq_indices = global_offset + tl.arange(0, SEQ_BLOCK)
+
+    limit = tl.minimum(seq_len, S * sp_size)
+    seq_mask = seq_indices < limit
+    seq_indices = tl.where(seq_mask, seq_indices, 0)
+
+    HW = H * W
+    f_idx = seq_indices // HW
+    rem = seq_indices - f_idx * HW
+    h_idx = rem // W
+    w_idx = rem - h_idx * W
+
+    freq_offset_cf = tl.arange(0, CfM)  # column offsets of segment 1: [0, Cf)
+    freq_offset_ch = Cf + tl.arange(0, ChM)  # column offsets of segment 2: [Cf, Cf+Ch)
+    freq_offset_cw = Cf + Ch + tl.arange(0, CwM)  # column offsets of segment 3: [Cf+Ch, C)
+    # gather the frequency value for every sequence position
+    # (broadcasting picks a different table row per position)
+    # lookup address into the frequency table = idx * C + col_offset
+    freq_addr_cf = f_idx[:, None] * C + freq_offset_cf[None, :]
+    freq_addr_ch = h_idx[:, None] * C + freq_offset_ch[None, :]
+    freq_addr_cw = w_idx[:, None] * C + freq_offset_cw[None, :]
+
+    freqs_real_cf = tl.load(
+        freqs_real_ptr + freq_addr_cf,
+        mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
+        other=1.0,
+    ).to(tl.float32)
+    freqs_imag_cf = tl.load(
+        freqs_imag_ptr + freq_addr_cf,
+        mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
+        other=1.0,
+    ).to(tl.float32)
+    freqs_real_ch = tl.load(
+        freqs_real_ptr + 
freq_addr_ch,
+        mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
+        other=1.0,
+    ).to(tl.float32)
+    freqs_imag_ch = tl.load(
+        freqs_imag_ptr + freq_addr_ch,
+        mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
+        other=1.0,
+    ).to(tl.float32)
+    freqs_real_cw = tl.load(
+        freqs_real_ptr + freq_addr_cw,
+        mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
+        other=1.0,
+    ).to(tl.float32)
+    freqs_imag_cw = tl.load(
+        freqs_imag_ptr + freq_addr_cw,
+        mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
+        other=1.0,
+    ).to(tl.float32)
+
+    # expand the frequency values so they broadcast against x over the head dimension
+    freqs_real_cf = freqs_real_cf[:, None, :]  # [SEQ_BLOCK, 1, Cf]
+    freqs_imag_cf = freqs_imag_cf[:, None, :]
+    freqs_real_ch = freqs_real_ch[:, None, :]
+    freqs_imag_ch = freqs_imag_ch[:, None, :]
+    freqs_real_cw = freqs_real_cw[:, None, :]
+    freqs_imag_cw = freqs_imag_cw[:, None, :]
+
+    # load the real/imaginary parts of the x block handled by this program
+    # (shape: [SEQ_BLOCK, HEADS_BLOCK, C])
+    seq_offset = seqlen_group_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
+    head_offset = head_group_idx * HEADS_BLOCK + tl.arange(0, HEADS_BLOCK)
+    # compute the offset addresses into x_ptr
+    base_offset = batch_idx * S * N * 2 * C
+    seq_head_offset = (
+        base_offset
+        + seq_offset[:, None, None] * (N * 2 * C)
+        + head_offset[None, :, None] * (2 * C)
+    )
+    x_mask = (seq_offset < S)[:, None, None] & (head_offset < N)[None, :, None]
+
+    # load the corresponding channel segment of x; positions beyond the valid length are masked to 0
+    # segment 1: channels [0, Cf-1]
+    chan_cf = tl.arange(0, CfM * 2)
+    mask_2cf_chan = chan_cf < Cf * 2
+    x_cf = tl.load(
+        x_ptr + seq_head_offset + chan_cf[None, None, :],
+        mask=(x_mask & mask_2cf_chan[None, None, :]),
+        other=0.0,
+    ).to(tl.float32)
+    x_cf = x_cf.reshape(
+        SEQ_BLOCK, HEADS_BLOCK, CfM, 2
+    )  # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
+    x_real_cf, x_imag_cf = x_cf.split()
+
+    # apply the RoPE rotation (segment 1)
+    out_real_cf = x_real_cf * freqs_real_cf - x_imag_cf * freqs_imag_cf
+    out_imag_cf = x_real_cf * freqs_imag_cf + x_imag_cf * freqs_real_cf
+
+    out_cf = tl.interleave(out_real_cf, out_imag_cf)  # [SEQ_BLOCK, HEADS_BLOCK, CfM * 2]
+    tl.store(
+        output_ptr + seq_head_offset + chan_cf[None, None, :],
+        out_cf,
+        mask=(x_mask & mask_2cf_chan[None, None, :]),
+    )
+
+    # segment 2: channels [Cf, Cf+Ch-1]
+    chan_ch = tl.arange(0, ChM * 2) + Cf * 2
+    mask_2ch_chan = chan_ch < 2 * (Cf + Ch)
+    x_ch = tl.load(
+        x_ptr + seq_head_offset + chan_ch[None, None, :],
+        mask=(x_mask & mask_2ch_chan[None, None, :]),
+        other=0.0,
+    ).to(tl.float32)
+    x_ch = x_ch.reshape(SEQ_BLOCK, HEADS_BLOCK, ChM, 2)
+    x_real_ch, x_imag_ch = x_ch.split()
+    out_real_ch = x_real_ch * freqs_real_ch - x_imag_ch * freqs_imag_ch
+    out_imag_ch = x_real_ch * freqs_imag_ch + x_imag_ch * freqs_real_ch
+
+    out_ch = tl.interleave(out_real_ch, out_imag_ch)  # [SEQ_BLOCK, HEADS_BLOCK, ChM * 2]
+    tl.store(
+        output_ptr + seq_head_offset + chan_ch[None, None, :],
+        out_ch,
+        mask=(x_mask & mask_2ch_chan[None, None, :]),
+    )
+
+    # segment 3: channels [Cf+Ch, C-1]
+    chan_cw = tl.arange(0, CwM * 2) + (Cf + Ch) * 2
+    mask_2cw_chan = chan_cw < 2 * C
+    x_cw = tl.load(
+        x_ptr + seq_head_offset + chan_cw[None, None, :],
+        mask=(x_mask & mask_2cw_chan[None, None, :]),
+        other=0.0,
+    ).to(tl.float32)
+    x_cw = x_cw.reshape(SEQ_BLOCK, HEADS_BLOCK, CwM, 2)
+    x_real_cw, x_imag_cw = x_cw.split()
+    out_real_cw = x_real_cw * freqs_real_cw - x_imag_cw * freqs_imag_cw
+    out_imag_cw = x_real_cw * freqs_imag_cw + x_imag_cw * freqs_real_cw
+
+    out_cw = tl.interleave(out_real_cw, out_imag_cw)
+    tl.store(
+        output_ptr + seq_head_offset + chan_cw[None, None, :],
+        out_cw,
+        mask=(x_mask & mask_2cw_chan[None, None, :]),
+    )
+
+
+@torch._dynamo.disable
+def rope_apply_triton(
+    x: torch.tensor,
+    grid_sizes: torch.tensor,
+    freqs: Tuple[torch.tensor],
+    sp_size: Optional[int] = None,
+    sp_rank: Optional[int] = None,
+) -> torch.tensor:
+    """
+    Example shapes:
+        x: [1, 9450, 40, 128]
+        grid_sizes: [[21, 45, 80]]
+        freqs_real: [1024, 64]
+        freqs_imag: [1024, 64]
+    """
+    B, S, N, C = x.shape
+    C = C // 2
+    Cf = C - 2 * (C // 3)  # frequency length for the first (frame) axis
+    Ch = C // 3  # frequency length for the second (height) axis
+    Cw = C // 3  # frequency length for the third (width) axis
+    M = freqs[0].shape[0]
+
+    SEQ_BLOCK = 64  # sequence positions handled by each thread block
+    HEADS_BLOCK = 8  # attention heads handled by each thread block
+
+    if sp_rank is None:
+        sp_size = 1
+        sp_rank = 0
+
+    grid_sizes = grid_sizes.to(device=x.device)
+    output = torch.empty_like(x)
+
+    rope_kernel[(B, triton.cdiv(S, SEQ_BLOCK), triton.cdiv(N, HEADS_BLOCK))](
+        x,
+        grid_sizes,
+        freqs[0],
+        freqs[-1],
+        output,
+        sp_size,
+        sp_rank,
+        B,
+        S,
+        N=N,
+        C=C,
+        M=M,
+        CfM=triton.next_power_of_2(Cf),
+        ChM=triton.next_power_of_2(Ch),
+        CwM=triton.next_power_of_2(Cw),
+        SEQ_BLOCK=SEQ_BLOCK,
+        HEADS_BLOCK=HEADS_BLOCK,
+        num_warps=32,
+        num_stages=3,
+    )
+
+    return output.float()
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
index c841b04..80ded5a 100644
--- a/wan/modules/t5.py
+++ b/wan/modules/t5.py
@@ -7,7 +7,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .tokenizers import HuggingfaceTokenizer
+try:
+    import torch_musa
+except ModuleNotFoundError:
+    torch_musa = None
+
+from wan.modules.tokenizers import HuggingfaceTokenizer
+from wan.utils.platform import get_device
 
 __all__ = [
     'T5Model',
@@ -59,10 +65,8 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
+        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
                             self.eps)
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            x = x.type_as(self.weight)
         return self.weight * x
 
 
@@ -110,7 +114,7 @@ class T5Attention(nn.Module):
 
         # compute attention (T5 does not use scaling)
         attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
-        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+        attn = F.softmax(attn, dim=-1)
         x = torch.einsum('bnij,bjnc->binc', attn, v)
 
         # output
@@ -255,7 +259,7 @@ class T5RelativeEmbedding(nn.Module):
 
         # embeddings for small and large positions
         max_exact = num_buckets // 2
-        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
+        rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
                                      math.log(self.max_dist / max_exact) *
                                      (num_buckets - max_exact)).long()
         rel_pos_large = torch.min(
@@ -475,7 +479,7 @@ class T5EncoderModel:
             self,
             text_len,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_device(),
             checkpoint_path=None,
             tokenizer_path=None,
             shard_fn=None,
diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py
index a12d1dd..8913c2e 100644
--- a/wan/modules/vace_model.py
+++ b/wan/modules/vace_model.py
@@ -4,6 +4,12 @@ import torch.cuda.amp as amp
 import torch.nn as nn
 from diffusers.configuration_utils import register_to_config
 
+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
 
 
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index 5c6da57..d9713cb 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -1,12 +1,21 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging +from math import sqrt import torch -import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F +from torch.nn import Upsample from einops import rearrange +try: + import torch_musa + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.utils.platform import get_device_type + __all__ = [ 'WanVAE', ] @@ -44,23 +53,17 @@ class RMS_norm(nn.Module): shape = (dim, *broadcastable_dims) if channel_first else (dim,) self.channel_first = channel_first - self.scale = dim**0.5 + self.scale = sqrt(dim) self.gamma = nn.Parameter(torch.ones(shape)) self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) + return ( + F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x) + * self.scale + * self.gamma + + self.bias + ) class Resample(nn.Module): @@ -253,6 +256,10 @@ class AttentionBlock(nn.Module): q, k, v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, ) x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) @@ -621,8 +628,8 @@ class WanVAE: def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cuda"): + dtype=torch.bfloat16, + device=get_device_type()): self.dtype = dtype self.device = device @@ -648,16 +655,12 @@ class WanVAE: """ videos: A list of videos each with shape [C, T, H, W]. """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] + return [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos + ] def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + return [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..cdaa3d7 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -11,19 +11,31 @@ from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist from tqdm import tqdm -from .distributed.fsdp import shard_model -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanT2V: @@ -60,7 +72,7 @@ class WanT2V: t5_cpu (`bool`, *optional*, defaults 
to False): Whether to place T5 model on CPU. Only works without t5_fsdp. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -82,6 +94,7 @@ class WanT2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) logging.info(f"Creating WanModel from {checkpoint_dir}") self.model = WanModel.from_pretrained(checkpoint_dir) @@ -103,7 +116,8 @@ class WanT2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -230,13 +244,13 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents timestep = [t] timestep = torch.stack(timestep) - self.model.to(self.device) noise_pred_cond = self.model( latent_model_input, t=timestep, **arg_c)[0] noise_pred_uncond = self.model( @@ -256,7 +270,7 @@ class WanT2V: x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -264,8 +278,9 @@ class WanT2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index 2e9b33d..179648c 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -5,9 +5,13 @@ from .fm_solvers import ( ) from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .vace_processor import VaceVideoProcessor +from .platform import get_device, get_device_type, get_torch_distributed_backend +from .memory_format import convert_conv3d_weight_memory_format +from .chrono_inspector import ChronoInspector __all__ = [ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', - 'VaceVideoProcessor' + 'VaceVideoProcessor', 'get_device', 'get_device_type', 'get_torch_distributed_backend', + 'convert_conv3d_weight_memory_format', 'ChronoInspector' ] diff --git a/wan/utils/chrono_inspector.py b/wan/utils/chrono_inspector.py new file mode 100644 index 0000000..f3f096e --- /dev/null +++ b/wan/utils/chrono_inspector.py @@ -0,0 +1,15 @@ +from time import perf_counter +from logging import info + + +class ChronoInspector(object): + def __init__(self, name:str="Block"): + self.name = name + + def __enter__(self): + self.start_time:float = perf_counter() + return self # 可选:返回 self 以获取更多信息 + + def __exit__(self, exc_type, exc_val, exc_tb): + end_time:float = perf_counter() + info(f"[{self.name}] Elapsed time: {end_time - self.start_time:.2f} seconds") diff --git a/wan/utils/memory_format.py b/wan/utils/memory_format.py new file mode 100644 index 0000000..a4b19ba --- /dev/null +++ b/wan/utils/memory_format.py @@ -0,0 +1,76 @@ +import torch + + +def convert_conv3d_weight_memory_format(module:torch.nn.Module, memory_format:torch.memory_format): + r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format`` + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. 
This function is used to facilitate the computation to adopt NDHWC kernels, which
+    provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
+
+    .. note::
+        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
+        than the utility function ``convert_conv3d_weight_memory_format``. Any
+        layer with a 5d weight will be affected by ``model.to``, which does not
+        necessarily benefit from conversion to the specified ``memory_format``.
+        One case we are confident about is the NDHWC (channels_last_3d) conversion for
+        convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
+        even in cases where we have to apply permutation to input tensors.
+
+        Hence our strategy here is to convert only the weight of convolution to
+        channels_last_3d. This ensures that:
+        1. Fast convolution kernels will be used, the benefit of which could
+           outweigh the overhead of permutation (if the input is not in the same format).
+        2. No unnecessary permutations are applied on layers that do not benefit
+           from memory_format conversion.
+
+        The optimal case is that layers between convolution layers are channels
+        last compatible. The input tensor would be permuted to channels last when it
+        encounters the first convolution layer and stay in that memory format.
+        Hence following convolutions will not need to permute their input tensors.
+
+        In case a channels-last-incompatible layer sits between convolution
+        layers, we need to permute the input tensor back to contiguous format
+        for that layer. The input tensor will go through the remaining layers in
+        contiguous format and be permuted to channels last when it encounters
+        another convolution layer. There's no point in propagating that
+        permutation to an earlier layer, as most layers are quite agnostic to
+        ``memory_format``.
+
+        This claim might change when PyTorch supports fusion of permutation, as
+        there might have been a better spot to fuse the permutation other than
+        immediately before a convolution.
+
+    Args:
+        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
+                            ``nn.Module``
+        memory_format: user specified ``memory_format``,
+            e.g. ``torch.channels_last_3d`` or ``torch.contiguous_format``
+
+    Returns:
+        The original module with updated ``nn.Conv3d`` weights
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
+        >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
+        >>> model = nn.Sequential(
+        >>>     nn.Conv3d(8, 4, 3)).cuda().half()
+        >>> # This is identical to:
+        >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
+        >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
+        >>> out = model(input)
+    """
+
+    # TODO: expand this to `_ConvNd` when channels_last support is extended
+    # beyond only 4d tensors.
+ if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = ( + module.weight.detach().clone().contiguous(memory_format=memory_format) + ) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + return module diff --git a/wan/utils/platform.py b/wan/utils/platform.py new file mode 100644 index 0000000..1f08530 --- /dev/null +++ b/wan/utils/platform.py @@ -0,0 +1,50 @@ +from typing import Optional + +import torch + +try: + import torch_musa +except ModuleNotFoundError: + torch_musa = None + + +def _is_musa(): + if torch_musa is None: + return False + else: + return True + + +def get_device(local_rank: Optional[int] = None) -> torch.device: + if torch.cuda.is_available(): + return ( + torch.cuda.current_device() + if local_rank is None + else torch.device("cuda", local_rank) + ) + elif _is_musa(): + return ( + torch.musa.current_device() + if local_rank is None + else torch.device("musa", local_rank) + ) + else: + return torch.device("cpu") + + +def get_device_type() -> str: + if torch.cuda.is_available(): + return "cuda" + elif _is_musa(): + return "musa" + else: + return "cpu" + + +def get_torch_distributed_backend() -> str: + if torch.cuda.is_available(): + return "nccl" + elif _is_musa(): + return "mccl" + else: + raise NotImplementedError("No Accelerators(NV/MTT GPU) available") diff --git a/wan/vace.py b/wan/vace.py index 8a4f744..6c5be34 100644 --- a/wan/vace.py +++ b/wan/vace.py @@ -13,6 +13,7 @@ from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F @@ -20,8 +21,17 @@ import torchvision.transforms.functional as TF from PIL import Image from tqdm import tqdm -from .modules.vace_model import VaceWanModel -from .text2video import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.modules.vace_model import VaceWanModel +from wan.text2video import ( FlowDPMSolverMultistepScheduler, FlowUniPCMultistepScheduler, T5EncoderModel, @@ -31,7 +41,9 @@ from .text2video import ( retrieve_timesteps, shard_model, ) -from .utils.vace_processor import VaceVideoProcessor +from wan.utils.vace_processor import VaceVideoProcessor +from wan.utils.platform import get_device, get_torch_distributed_backend +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanVace(WanT2V): @@ -68,7 +80,7 @@ class WanVace(WanT2V): t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. 
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -90,6 +102,7 @@ class WanVace(WanT2V): self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) logging.info(f"Creating VaceWanModel from {checkpoint_dir}") self.model = VaceWanModel.from_pretrained(checkpoint_dir) @@ -117,7 +130,8 @@ class WanVace(WanT2V): self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -460,7 +474,7 @@ class WanVace(WanT2V): x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.decode_latent(x0, input_ref_images) @@ -468,9 +482,10 @@ class WanVace(WanT2V): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None @@ -568,7 +583,7 @@ class WanVaceMP(WanVace): torch.cuda.set_device(gpu) dist.init_process_group( - backend='nccl', + backend=get_torch_distributed_backend(), init_method='env://', rank=rank, world_size=world_size) @@ -629,11 +644,11 @@ class WanVaceMP(WanVace): else: sp_size = 1 - dist.barrier() + # dist.barrier() model = shard_fn(model) sample_neg_prompt = self.config.sample_neg_prompt - torch.cuda.empty_cache() + empty_cache() event = initialized_events[gpu] in_q = in_q_list[gpu] event.set() @@ -748,7 +763,7 @@ class WanVaceMP(WanVace): generator=seed_g)[0] latents = [temp_x0.squeeze(0)] - torch.cuda.empty_cache() + empty_cache() x0 = latents if rank == 0: videos = self.decode_latent( @@ -758,9 +773,10 @@ class WanVaceMP(WanVace): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if rank == 0: out_q.put(videos[0].cpu())