diff --git a/.gitignore b/.gitignore index deeef7a..38c5296 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,7 @@ *.html *.pdf *.whl -cache +*cache/ __pycache__/ storage/ samples/ @@ -29,9 +29,11 @@ samples/ !requirements.txt .DS_Store *DS_Store +.vscode google/ Wan2.1-T2V-14B/ Wan2.1-T2V-1.3B/ Wan2.1-I2V-14B-480P/ Wan2.1-I2V-14B-720P/ -poetry.lock \ No newline at end of file +poetry.lock +logs/ diff --git a/generate.py b/generate.py index c841c19..aab7d7d 100644 --- a/generate.py +++ b/generate.py @@ -12,12 +12,25 @@ import random import torch import torch.distributed as dist +from torch.cuda import set_device from PIL import Image +try: + import torch_musa + from torch_musa.core.device import set_device +except ModuleNotFoundError: + torch_musa = None + import wan from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.utils import cache_image, cache_video, str2bool +from wan.utils.platform import ( + get_device_type, + get_torch_distributed_backend, + get_torch_profiler_activities, +) + EXAMPLE_PROMPT = { @@ -243,6 +256,11 @@ def _parse_args(): type=float, default=5.0, help="Classifier free guidance scale.") + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="profile the generating procedure.") args = parser.parse_args() @@ -263,6 +281,30 @@ def _init_logging(rank): logging.basicConfig(level=logging.ERROR) +def _init_profiler(): + profiler = torch.profiler.profile( + activities=get_torch_profiler_activities(), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + profiler.start() + return profiler + + +def _finalize_profiler(profiler): + profiler.stop() + table = profiler.key_averages().table( + sort_by=f"{get_device_type()}_time_total", + row_limit=20, + ) + file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + with open(f"logs/profiling-{file_name}.txt", "w") as f: + f.write(table) + del file_name + + def generate(args): rank = int(os.getenv("RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) @@ -275,9 +317,9 @@ def generate(args): logging.info( f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: - torch.cuda.set_device(local_rank) + set_device(local_rank) dist.init_process_group( - backend="nccl", + backend=get_torch_distributed_backend(), init_method="env://", rank=rank, world_size=world_size) @@ -329,6 +371,10 @@ def generate(args): base_seed = [args.base_seed] if rank == 0 else [None] dist.broadcast_object_list(base_seed, src=0) args.base_seed = base_seed[0] + + profiler = None + if args.profile and rank == 0: + profiler = _init_profiler() if "t2v" in args.task or "t2i" in args.task: if args.prompt is None: @@ -366,10 +412,23 @@ def generate(args): dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, + profiler=profiler, ) - logging.info( - f"Generating {'image' if 't2i' in args.task else 'video'} ...") + logging.info("Warming up WanT2V pipeline ...") + with torch.no_grad(): + _ = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...") video = 
wan_t2v.generate( args.prompt, size=SIZE_CONFIGS[args.size], @@ -423,8 +482,23 @@ def generate(args): dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, + profiler=profiler, ) + logging.info("Warming up WanI2V pipeline ...") + with torch.no_grad(): + _ = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + logging.info("Generating video ...") video = wan_i2v.generate( args.prompt, @@ -481,8 +555,24 @@ def generate(args): dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, + profiler=profiler ) + logging.info("Warming up WanFLF2V pipeline ...") + with torch.no_grad(): + _ = wan_flf2v.generate( + args.prompt, + first_frame, + last_frame, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + logging.info("Generating video ...") video = wan_flf2v.generate( args.prompt, @@ -529,6 +619,7 @@ def generate(args): dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, + profiler=profiler ) src_video, src_mask, src_ref_images = wan_vace.prepare_source( @@ -537,6 +628,22 @@ def generate(args): args.src_ref_images.split(',') ], args.frame_num, SIZE_CONFIGS[args.size], device) + logging.info("Warming up VACE pipeline ...") + with torch.no_grad(): + _ = wan_vace.generate( + args.prompt, + src_video, + src_mask, + src_ref_images, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=3, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + logging.info(f"Generating video...") video = wan_vace.generate( args.prompt, @@ -554,6 +661,9 @@ def generate(args): else: raise ValueError(f"Unkown task type: {args.task}") + if args.profile and rank == 0: + _finalize_profiler(profiler) + if rank == 0: if args.save_file is None: formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 6bb496d..ca12048 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -3,11 +3,17 @@ import gc from functools import partial import torch +from torch.cuda import empty_cache from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.utils import _free_storage +try: + import torch_musa + from torch_musa.core.memory import empty_cache +except ModuleNotFoundError: + torch_musa = None def shard_model( model, @@ -40,4 +46,4 @@ def free_model(model): _free_storage(m._handle.flat_param.data) del model gc.collect() - torch.cuda.empty_cache() + empty_cache() diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index 4718577..765c95e 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -6,7 +6,18 @@ from xfuser.core.distributed import ( get_sequence_parallel_world_size, get_sp_group, ) -from 
xfuser.core.long_ctx_attention import xFuserLongContextAttention +from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType +attn_type:AttnType = AttnType.FA + +from wan.modules.rope import rope_apply_pytorch, rope_apply_triton + +try: + import torch_musa + import torch_musa.core.amp as amp + attn_type = AttnType.TORCH + torch.backends.mudnn.allow_tf32 = True +except ImportError: + torch_musa = None from ..modules.model import sinusoidal_embedding_1d @@ -25,7 +36,7 @@ def pad_freqs(original_tensor, target_len): @amp.autocast(enabled=False) -def rope_apply(x, grid_sizes, freqs): +def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank): """ x: [B, L, N, C]. grid_sizes: [B, 3]. @@ -51,8 +62,6 @@ def rope_apply(x, grid_sizes, freqs): dim=-1).reshape(seq_len, 1, -1) # apply rotary embedding - sp_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() freqs_i = pad_freqs(freqs_i, s * sp_size) s_per_rank = s freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * @@ -109,9 +118,13 @@ def usp_dit_forward( if self.model_type == 'i2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) if self.model_type != 'vace' and y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -129,11 +142,9 @@ def usp_dit_forward( ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -177,7 +188,7 @@ def usp_dit_forward( # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def usp_attn_forward(self, @@ -200,8 +211,12 @@ def usp_attn_forward(self, return q, k, v q, k, v = qkv_fn(x) - q = rope_apply(q, grid_sizes, freqs) - k = rope_apply(k, grid_sizes, freqs) + if torch_musa is None: + q = rope_apply(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + k = rope_apply(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + else: + q = rope_apply_pytorch(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) + k = rope_apply_pytorch(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank()) # TODO: We should use unpaded q,k,v for attention. 
# k_lens = seq_lens // get_sequence_parallel_world_size() @@ -210,7 +225,7 @@ def usp_attn_forward(self, # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) - x = xFuserLongContextAttention()( + x = xFuserLongContextAttention(attn_type=attn_type)( None, query=half(q), key=half(k), diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 232950f..1f42101 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -6,27 +6,40 @@ import os import random import sys import types +from time import perf_counter from contextlib import contextmanager from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm -from .distributed.fsdp import shard_model -from .modules.clip import CLIPModel -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.clip import CLIPModel +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanFLF2V: @@ -42,6 +55,7 @@ class WanFLF2V: use_usp=False, t5_cpu=False, init_on_cpu=True, + profiler=None, ): r""" Initializes the image-to-video generation model components. @@ -66,7 +80,7 @@ class WanFLF2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -90,6 +104,7 @@ class WanFLF2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) self.clip = CLIPModel( dtype=config.clip_dtype, @@ -121,7 +136,8 @@ class WanFLF2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -129,6 +145,7 @@ class WanFLF2V: self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt + self.profiler = profiler def generate(self, input_prompt, @@ -183,6 +200,11 @@ class WanFLF2V: - H: Frame height (from max_area) - W: Frame width from max_area) """ + start_time = 0.0 + end_time = 0.0 + if self.rank == 0: + start_time = perf_counter() + first_frame_size = first_frame.size last_frame_size = last_frame.size first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to( @@ -275,6 +297,10 @@ class WanFLF2V: ])[0] y = torch.concat([msk, y]) + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds") + @contextmanager def noop_no_sync(): yield @@ -323,10 +349,16 @@ class WanFLF2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() + + if self.rank == 0: + start_time = perf_counter() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): + if self.profiler and self.rank == 0: + self.profiler.step() + latent_model_input = [latent.to(self.device)] timestep = [t] @@ -336,12 +368,12 @@ class WanFLF2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -356,22 +388,30 @@ class WanFLF2V: generator=seed_g)[0] latent = temp_x0.squeeze(0) - x0 = [latent.to(self.device)] - del latent_model_input, timestep + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds") + + x0 = [latent.to(self.device)] + del latent_model_input, timestep if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.vae.decode(x0) + end_time = perf_counter() + logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds") del noise, latent del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..c385e8a 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -6,27 +6,40 @@ import os import random import sys import types +from time import perf_counter from contextlib import contextmanager from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm 
import tqdm -from .distributed.fsdp import shard_model -from .modules.clip import CLIPModel -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.clip import CLIPModel +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanI2V: @@ -42,6 +55,7 @@ class WanI2V: use_usp=False, t5_cpu=False, init_on_cpu=True, + profiler=None, ): r""" Initializes the image-to-video generation model components. @@ -66,7 +80,7 @@ class WanI2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -90,6 +104,7 @@ class WanI2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) self.clip = CLIPModel( dtype=config.clip_dtype, @@ -121,7 +136,8 @@ class WanI2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -129,6 +145,7 @@ class WanI2V: self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt + self.profiler = profiler def generate(self, input_prompt, @@ -178,6 +195,11 @@ class WanI2V: - H: Frame height (from max_area) - W: Frame width from max_area) """ + start_time = 0.0 + end_time = 0.0 + if self.rank == 0: + start_time = perf_counter() + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) F = frame_num @@ -248,6 +270,10 @@ class WanI2V: ])[0] y = torch.concat([msk, y]) + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds") + @contextmanager def noop_no_sync(): yield @@ -296,10 +322,16 @@ class WanI2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() + + if self.rank == 0: + start_time = perf_counter() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): + if self.profiler and self.rank == 0: + self.profiler.step() + latent_model_input = [latent.to(self.device)] timestep = [t] @@ -309,12 +341,12 @@ class WanI2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -329,22 +361,30 @@ class WanI2V: 
generator=seed_g)[0] latent = temp_x0.squeeze(0) - x0 = [latent.to(self.device)] - del latent_model_input, timestep + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds") + + x0 = [latent.to(self.device)] + del latent_model_input, timestep if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.vae.decode(x0) + end_time = perf_counter() + logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds") del noise, latent del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 4dbbe03..796c633 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -1,4 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import warnings + import torch try: @@ -13,7 +15,14 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False -import warnings +try: + import torch_musa + FLASH_ATTN_3_AVAILABLE = False + FLASH_ATTN_2_AVAILABLE = False + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'flash_attention', @@ -51,7 +60,7 @@ def flash_attention( """ half_dtypes = (torch.float16, torch.bfloat16) assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 + assert q.device.type in ("cuda", "musa") and q.size(-1) <= 256 # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype @@ -172,8 +181,9 @@ def attention( k = k.transpose(1, 2).to(dtype) v = v.transpose(1, 2).to(dtype) - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale) out = out.transpose(1, 2).contiguous() return out diff --git a/wan/modules/clip.py b/wan/modules/clip.py index 42dda04..496ed58 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -6,12 +6,20 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import torch.cuda.amp as amp import torchvision.transforms as T from .attention import flash_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta +try: + import torch_musa + import torch_musa.core.amp as amp + from .attention import attention as flash_attention +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'XLMRobertaCLIP', 'clip_xlm_roberta_vit_h_14', @@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len): return torch.cat([ pos[:, :n], F.interpolate( - pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + pos[:, n:].reshape(1, src_grid, src_grid, -1).permute( 0, 3, 1, 2), size=(tar_grid, tar_grid), mode='bicubic', @@ -44,12 +52,6 @@ class QuickGELU(nn.Module): return x * torch.sigmoid(1.702 * x) -class LayerNorm(nn.LayerNorm): - - def forward(self, x): - return super().forward(x.float()).type_as(x) - - class SelfAttention(nn.Module): def __init__(self, @@ -82,7 +84,7 @@ class SelfAttention(nn.Module): # compute attention p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, 
dropout_p=p, causal=self.causal, version=2) + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal) x = x.reshape(b, s, c) # output @@ -131,10 +133,10 @@ class AttentionBlock(nn.Module): self.norm_eps = norm_eps # layers - self.norm1 = LayerNorm(dim, eps=norm_eps) + self.norm1 = nn.LayerNorm(dim, eps=norm_eps) self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) - self.norm2 = LayerNorm(dim, eps=norm_eps) + self.norm2 = nn.LayerNorm(dim, eps=norm_eps) if activation == 'swi_glu': self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) else: @@ -177,7 +179,7 @@ class AttentionPool(nn.Module): self.to_q = nn.Linear(dim, dim) self.to_kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) - self.norm = LayerNorm(dim, eps=norm_eps) + self.norm = nn.LayerNorm(dim, eps=norm_eps) self.mlp = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == 'quick_gelu' else nn.GELU(), @@ -259,13 +261,13 @@ class VisionTransformer(nn.Module): self.dropout = nn.Dropout(embedding_dropout) # transformer - self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None self.transformer = nn.Sequential(*[ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers) ]) - self.post_norm = LayerNorm(dim, eps=norm_eps) + self.post_norm = nn.LayerNorm(dim, eps=norm_eps) # head if pool_type == 'token': @@ -537,6 +539,6 @@ class CLIPModel: videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) # forward - with torch.cuda.amp.autocast(dtype=self.dtype): + with amp.autocast(dtype=self.dtype): out = self.model.visual(videos, use_31_block=True) return out diff --git a/wan/modules/model.py b/wan/modules/model.py index a5425da..7327c6c 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -7,7 +7,16 @@ import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin -from .attention import flash_attention +from wan.modules.attention import flash_attention +from wan.modules.rope import rope_apply_pytorch + +try: + import torch_musa + import torch_musa.core.amp as amp + from wan.modules.attention import attention as flash_attention + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None __all__ = ['WanModel'] @@ -19,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 - position = position.type(torch.float64) + position = position.type(torch.bfloat16) # calculation sinusoid = torch.outer( @@ -29,14 +38,33 @@ def sinusoidal_embedding_1d(dim, position): @amp.autocast(enabled=False) -def rope_params(max_seq_len, dim, theta=10000): +def rope_params_real( + max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu") +): assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), - 1.0 / torch.pow(theta, - torch.arange(0, dim, 2).to(torch.float64).div(dim))) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs + freqs_real = torch.outer( + torch.arange(max_seq_len, dtype=dtype, device=device), + 1.0 + / torch.pow( + theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim) + ), + ) + return torch.cos(freqs_real) + + +@amp.autocast(enabled=False) +def rope_params_imag( + max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu") +): + assert dim % 2 == 0 + freqs_imag = 
torch.outer( + torch.arange(max_seq_len, dtype=dtype, device=device), + 1.0 + / torch.pow( + theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim) + ), + ) + return torch.sin(freqs_imag) @amp.autocast(enabled=False) @@ -89,19 +117,6 @@ class WanRMSNorm(nn.Module): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) -class WanLayerNorm(nn.LayerNorm): - - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return super().forward(x.float()).type_as(x) - - class WanSelfAttention(nn.Module): def __init__(self, @@ -145,10 +160,16 @@ class WanSelfAttention(nn.Module): return q, k, v q, k, v = qkv_fn(x) + if torch_musa is None: + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + else: + q = rope_apply_pytorch(q, grid_sizes, freqs) + k = rope_apply_pytorch(k, grid_sizes, freqs) x = flash_attention( - q=rope_apply(q, grid_sizes, freqs), - k=rope_apply(k, grid_sizes, freqs), + q=q, + k=k, v=v, k_lens=seq_lens, window_size=self.window_size) @@ -256,10 +277,10 @@ class WanAttentionBlock(nn.Module): self.eps = eps # layers - self.norm1 = WanLayerNorm(dim, eps) + self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) - self.norm3 = WanLayerNorm( + self.norm3 = nn.LayerNorm( dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, @@ -267,7 +288,7 @@ class WanAttentionBlock(nn.Module): (-1, -1), qk_norm, eps) - self.norm2 = WanLayerNorm(dim, eps) + self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim)) @@ -293,24 +314,19 @@ class WanAttentionBlock(nn.Module): grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 + e = (self.modulation + e).chunk(6, dim=1) # self-attention y = self.self_attn( - self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, + self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) - with amp.autocast(dtype=torch.float32): - x = x + y * e[2] + x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) - with amp.autocast(dtype=torch.float32): - x = x + y * e[5] + y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) + x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) @@ -328,7 +344,7 @@ class Head(nn.Module): # layers out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) + self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False) self.head = nn.Linear(dim, out_dim) # modulation @@ -340,10 +356,8 @@ class Head(nn.Module): x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = 
self.head(self.norm(x) * (1 + e[1]) + e[0]) return x @@ -477,12 +491,33 @@ class WanModel(ModelMixin, ConfigMixin): # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) + if torch_musa is None: + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + else: + freqs_real = torch.cat( + [ + rope_params_real(1024, d - 4 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + ], + dim=1, + ) + freqs_imag = torch.cat( + [ + rope_params_imag(1024, d - 4 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + ], + dim=1, + ) + self.freqs = (freqs_real, freqs_imag) if model_type == 'i2v' or model_type == 'flf2v': self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') @@ -523,9 +558,17 @@ class WanModel(ModelMixin, ConfigMixin): if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if torch_musa is None: + if self.freqs.dtype != dtype or self.freqs.device != device: + self.freqs = self.freqs.to(dtype=dtype, device=device) + else: + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device), + ) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -543,11 +586,9 @@ class WanModel(ModelMixin, ConfigMixin): ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -579,7 +620,7 @@ class WanModel(ModelMixin, ConfigMixin): # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def unpatchify(self, x, grid_sizes): r""" diff --git a/wan/modules/rope.py b/wan/modules/rope.py new file mode 100644 index 0000000..ca204ab --- /dev/null +++ b/wan/modules/rope.py @@ -0,0 +1,317 @@ +from typing import Optional, Tuple + +import triton +import triton.language as tl +import torch + + +def pad_tensor( + original_tensor: torch.tensor, target_len: int, pad_value: float = 0.0 +) -> torch.tensor: + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.full( + (pad_size, s1, s2), + pad_value, + dtype=original_tensor.dtype, + device=original_tensor.device, + ) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +def rope_apply_pytorch( + x: torch.tensor, + grid_sizes: torch.tensor, + freqs: Tuple[torch.tensor], + sp_size: Optional[int] = None, + sp_rank: Optional[int] = None, +) -> torch.tensor: + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. 
+ """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + c0 = c - 2 * (c // 3) + c1 = c // 3 + c2 = c // 3 + + # split freqs + freqs_real = freqs[0].split([c0, c1, c2], dim=1) + freqs_imag = freqs[-1].split([c0, c1, c2], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = x[i, :seq_len].reshape(s, n, -1, 2) + x_real = x_i[..., 0] + x_imag = x_i[..., 1] + freqs_real = torch.cat( + [ + freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0), + freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1), + freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + freqs_imag = torch.cat( + [ + freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0), + freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1), + freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + if sp_rank is None: + freqs_real_rank = freqs_real + freqs_imag_rank = freqs_imag + else: + freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0) + freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0) + freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :] + freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :] + + out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank + out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank + + x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2) + x_out = torch.cat([x_out, x[i, seq_len:]], dim=0) + + # append to collection + output.append(x_out) + + return torch.stack(output) + + +@triton.jit +def rope_kernel( + x_ptr, # [B, S, N, 2C] + grid_sizes_ptr, # [B, 3] + freqs_real_ptr, # [M, C] + freqs_imag_ptr, # [M, C] + output_ptr, # [B, S, N, 2C] + sp_size, # SP world size + sp_rank, # SP rank + B, + S, + N: tl.constexpr, + C: tl.constexpr, + M: tl.constexpr, + CfM: tl.constexpr, + ChM: tl.constexpr, + CwM: tl.constexpr, + SEQ_BLOCK: tl.constexpr, + HEADS_BLOCK: tl.constexpr, +): + Cf = C - 2 * (C // 3) + Ch = C // 3 + Cw = C // 3 + + batch_idx = tl.program_id(0) + seqlen_group_idx = tl.program_id(1) + head_group_idx = tl.program_id(2) + + base = batch_idx * 3 + F = tl.load(grid_sizes_ptr + base + 0) + H = tl.load(grid_sizes_ptr + base + 1) + W = tl.load(grid_sizes_ptr + base + 2) + seq_len = F * H * W + + global_offset = sp_rank * S + seqlen_group_idx * SEQ_BLOCK + seq_indices = global_offset + tl.arange(0, SEQ_BLOCK) + + limit = tl.minimum(seq_len, S * sp_size) + seq_mask = seq_indices < limit + seq_indices = tl.where(seq_mask, seq_indices, 0) + + HW = H * W + f_idx = seq_indices // HW + rem = seq_indices - f_idx * HW + h_idx = rem // W + w_idx = rem - h_idx * W + + freq_offset_cf = tl.arange(0, CfM) # 第1段列偏移 [0, Cf) + freq_offset_ch = Cf + tl.arange(0, ChM) # 第2段列偏移 [Cf, Cf+Ch) + freq_offset_cw = Cf + Ch + tl.arange(0, CwM) # 第3段列偏移 [Cf+Ch, C) + # 按照每个序列位置取对应频率值 (利用广播计算每个位置不同行的值) + # 频率表取值地址 = idx * C + col_offset + freq_addr_cf = f_idx[:, None] * C + freq_offset_cf[None, :] + freq_addr_ch = h_idx[:, None] * C + freq_offset_ch[None, :] + freq_addr_cw = w_idx[:, None] * C + freq_offset_cw[None, :] + + freqs_real_cf = tl.load( + freqs_real_ptr + freq_addr_cf, + mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)), + other=1.0, + ).to(tl.float32) + freqs_imag_cf = tl.load( + freqs_imag_ptr + freq_addr_cf, + mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)), + other=1.0, + ).to(tl.float32) + freqs_real_ch = tl.load( + freqs_real_ptr + 
freq_addr_ch, + mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)), + other=1.0, + ).to(tl.float32) + freqs_imag_ch = tl.load( + freqs_imag_ptr + freq_addr_ch, + mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)), + other=1.0, + ).to(tl.float32) + freqs_real_cw = tl.load( + freqs_real_ptr + freq_addr_cw, + mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)), + other=1.0, + ).to(tl.float32) + freqs_imag_cw = tl.load( + freqs_imag_ptr + freq_addr_cw, + mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)), + other=1.0, + ).to(tl.float32) + + # 将频率值扩展维度以便与x相乘 (在head维度上广播) + freqs_real_cf = freqs_real_cf[:, None, :] # [SEQ_BLOCK, 1, Cf] + freqs_imag_cf = freqs_imag_cf[:, None, :] + freqs_real_ch = freqs_real_ch[:, None, :] + freqs_imag_ch = freqs_imag_ch[:, None, :] + freqs_real_cw = freqs_real_cw[:, None, :] + freqs_imag_cw = freqs_imag_cw[:, None, :] + + # 加载输入x对应块的实部和虚部 (形状: [SEQ_BLOCK, HEADS_BLOCK, C]) + seq_offset = seqlen_group_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + head_offset = head_group_idx * HEADS_BLOCK + tl.arange(0, HEADS_BLOCK) + # 计算x_ptr偏移地址 + base_offset = batch_idx * S * N * 2 * C + seq_head_offset = ( + base_offset + + seq_offset[:, None, None] * (N * 2 * C) + + head_offset[None, :, None] * (2 * C) + ) + x_mask = (seq_offset < S)[:, None, None] & (head_offset < N)[None, :, None] + + # 加载输入 x 的对应通道段数据,超出实际长度部分掩码为0 + # 段1:通道 [0, Cf-1] + chan_cf = tl.arange(0, CfM * 2) + mask_2cf_chan = chan_cf < Cf * 2 + x_cf = tl.load( + x_ptr + seq_head_offset + chan_cf[None, None, :], + mask=(x_mask & mask_2cf_chan[None, None, :]), + other=0.0, + ).to(tl.float32) + x_cf = x_cf.reshape( + SEQ_BLOCK, HEADS_BLOCK, CfM, 2 + ) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2] + x_real_cf, x_imag_cf = x_cf.split() + + # 计算 RoPE 旋转(段1) + out_real_cf = x_real_cf * freqs_real_cf - x_imag_cf * freqs_imag_cf + out_imag_cf = x_real_cf * freqs_imag_cf + x_imag_cf * freqs_real_cf + + out_cf = tl.interleave(out_real_cf, out_imag_cf) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2] + tl.store( + output_ptr + seq_head_offset + chan_cf[None, None, :], + out_cf, + mask=(x_mask & mask_2cf_chan[None, None, :]), + ) + + # 段2:通道 [Cf, Cf+Ch-1] + chan_ch = tl.arange(0, ChM * 2) + Cf * 2 + mask_2ch_chan = chan_ch < 2 * (Cf + Ch) + x_ch = tl.load( + x_ptr + seq_head_offset + chan_ch[None, None, :], + mask=(x_mask & mask_2ch_chan[None, None, :]), + other=0.0, + ).to(tl.float32) + x_ch = x_ch.reshape(SEQ_BLOCK, HEADS_BLOCK, ChM, 2) + x_real_ch, x_imag_ch = x_ch.split() + out_real_ch = x_real_ch * freqs_real_ch - x_imag_ch * freqs_imag_ch + out_imag_ch = x_real_ch * freqs_imag_ch + x_imag_ch * freqs_real_ch + + out_ch = tl.interleave(out_real_ch, out_imag_ch) # [SEQ_BLOCK, HEADS_BLOCK, ChM, 2] + tl.store( + output_ptr + seq_head_offset + chan_ch[None, None, :], + out_ch, + mask=(x_mask & mask_2ch_chan[None, None, :]), + ) + + # 段3:通道 [Cf+Ch, C-1] + chan_cw = tl.arange(0, CwM * 2) + (Cf + Ch) * 2 + mask_2cw_chan = chan_cw < 2 * C + x_cw = tl.load( + x_ptr + seq_head_offset + chan_cw[None, None, :], + mask=(x_mask & mask_2cw_chan[None, None, :]), + other=0.0, + ).to(tl.float32) + x_cw = x_cw.reshape(SEQ_BLOCK, HEADS_BLOCK, CwM, 2) + x_real_cw, x_imag_cw = x_cw.split() + out_real_cw = x_real_cw * freqs_real_cw - x_imag_cw * freqs_imag_cw + out_imag_cw = x_real_cw * freqs_imag_cw + x_imag_cw * freqs_real_cw + + out_cw = tl.interleave(out_real_cw, out_imag_cw) + tl.store( + output_ptr + seq_head_offset + chan_cw[None, None, :], + out_cw, + mask=(x_mask & mask_2cw_chan[None, None, :]), + ) + + +@torch._dynamo.disable 
+def rope_apply_triton(
+    x: torch.tensor,
+    grid_sizes: torch.tensor,
+    freqs: Tuple[torch.tensor],
+    sp_size: Optional[int] = None,
+    sp_rank: Optional[int] = None,
+) -> torch.tensor:
+    """
+    x: [1, 9450, 40, 128]
+    grid_sizes: [[21, 45, 80]]
+    freqs_real: [1024, 64]
+    freqs_imag: [1024, 64]
+    """
+    B, S, N, C = x.shape
+    C = C // 2
+    Cf = C - 2 * (C // 3)  # number of frequencies for the first (frame) dimension
+    Ch = C // 3  # number of frequencies for the second (height) dimension
+    Cw = C // 3  # number of frequencies for the third (width) dimension
+    M = freqs[0].shape[0]
+
+    SEQ_BLOCK = 64  # sequence positions processed per program (thread block)
+    HEADS_BLOCK = 8  # attention heads processed per program (thread block)
+
+    if sp_rank is None:
+        sp_size = 1
+        sp_rank = 0
+
+    grid_sizes = grid_sizes.to(device=x.device)
+    output = torch.empty_like(x)
+
+    rope_kernel[(B, triton.cdiv(S, SEQ_BLOCK), triton.cdiv(N, HEADS_BLOCK))](
+        x,
+        grid_sizes,
+        freqs[0],
+        freqs[-1],
+        output,
+        sp_size,
+        sp_rank,
+        B,
+        S,
+        N=N,
+        C=C,
+        M=M,
+        CfM=triton.next_power_of_2(Cf),
+        ChM=triton.next_power_of_2(Ch),
+        CwM=triton.next_power_of_2(Cw),
+        SEQ_BLOCK=SEQ_BLOCK,
+        HEADS_BLOCK=HEADS_BLOCK,
+        num_warps=32,
+        num_stages=3,
+    )
+
+    return output.float()
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
index c841b04..80ded5a 100644
--- a/wan/modules/t5.py
+++ b/wan/modules/t5.py
@@ -7,7 +7,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .tokenizers import HuggingfaceTokenizer
+try:
+    import torch_musa
+except ModuleNotFoundError:
+    torch_musa = None
+
+from wan.modules.tokenizers import HuggingfaceTokenizer
+from wan.utils.platform import get_device
 
 __all__ = [
     'T5Model',
@@ -59,10 +65,8 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
+        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
                             self.eps)
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            x = x.type_as(self.weight)
         return self.weight * x
 
 
@@ -110,7 +114,7 @@ class T5Attention(nn.Module):
 
         # compute attention (T5 does not use scaling)
         attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
-        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+        attn = F.softmax(attn, dim=-1)
         x = torch.einsum('bnij,bjnc->binc', attn, v)
 
         # output
@@ -255,7 +259,7 @@ class T5RelativeEmbedding(nn.Module):
 
         # embeddings for small and large positions
         max_exact = num_buckets // 2
-        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
+        rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
                                      math.log(self.max_dist / max_exact) *
                                      (num_buckets - max_exact)).long()
         rel_pos_large = torch.min(
@@ -475,7 +479,7 @@ class T5EncoderModel:
             self,
             text_len,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_device(),
             checkpoint_path=None,
             tokenizer_path=None,
             shard_fn=None,
diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py
index a12d1dd..8913c2e 100644
--- a/wan/modules/vace_model.py
+++ b/wan/modules/vace_model.py
@@ -4,6 +4,12 @@ import torch.cuda.amp as amp
 import torch.nn as nn
 from diffusers.configuration_utils import register_to_config
 
+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
 
 
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index 5c6da57..d9713cb 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -1,12 +1,21 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging +from math import sqrt import torch -import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F +from torch.nn import Upsample from einops import rearrange +try: + import torch_musa + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.utils.platform import get_device_type + __all__ = [ 'WanVAE', ] @@ -44,23 +53,17 @@ class RMS_norm(nn.Module): shape = (dim, *broadcastable_dims) if channel_first else (dim,) self.channel_first = channel_first - self.scale = dim**0.5 + self.scale = sqrt(dim) self.gamma = nn.Parameter(torch.ones(shape)) self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) + return ( + F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x) + * self.scale + * self.gamma + + self.bias + ) class Resample(nn.Module): @@ -253,6 +256,10 @@ class AttentionBlock(nn.Module): q, k, v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, ) x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) @@ -621,8 +628,8 @@ class WanVAE: def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cuda"): + dtype=torch.bfloat16, + device=get_device_type()): self.dtype = dtype self.device = device @@ -648,16 +655,12 @@ class WanVAE: """ videos: A list of videos each with shape [C, T, H, W]. """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] + return [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos + ] def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + return [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..acf5bd4 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -6,24 +6,37 @@ import os import random import sys import types +from time import perf_counter from contextlib import contextmanager from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist from tqdm import tqdm -from .distributed.fsdp import shard_model -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.distributed.fsdp import shard_model +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.utils.platform import get_device +from wan.utils.memory_format import 
convert_conv3d_weight_memory_format class WanT2V: @@ -38,6 +51,7 @@ class WanT2V: dit_fsdp=False, use_usp=False, t5_cpu=False, + profiler=None, ): r""" Initializes the Wan text-to-video generation model components. @@ -60,7 +74,7 @@ class WanT2V: t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -82,6 +96,7 @@ class WanT2V: self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) logging.info(f"Creating WanModel from {checkpoint_dir}") self.model = WanModel.from_pretrained(checkpoint_dir) @@ -103,13 +118,15 @@ class WanT2V: self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt + self.profiler = profiler def generate(self, input_prompt, @@ -155,6 +172,11 @@ class WanT2V: - H: Frame height (from size) - W: Frame width from size) """ + start_time = 0.0 + end_time = 0.0 + if self.rank == 0: + start_time = perf_counter() + # preprocess F = frame_num target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, @@ -194,6 +216,10 @@ class WanT2V: generator=seed_g) ] + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds") + @contextmanager def noop_no_sync(): yield @@ -230,13 +256,19 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + if self.rank == 0: + start_time = perf_counter() + + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): + if self.profiler and self.rank == 0: + self.profiler.step() + latent_model_input = latents timestep = [t] timestep = torch.stack(timestep) - self.model.to(self.device) noise_pred_cond = self.model( latent_model_input, t=timestep, **arg_c)[0] noise_pred_uncond = self.model( @@ -253,19 +285,27 @@ class WanT2V: generator=seed_g)[0] latents = [temp_x0.squeeze(0)] + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds") + x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.vae.decode(x0) + end_time = perf_counter() + logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds") del noise, latents del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if self.rank == 0 else None diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index 2e9b33d..179648c 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -5,9 +5,13 @@ from .fm_solvers import ( ) from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .vace_processor import VaceVideoProcessor +from .platform import get_device, get_device_type, get_torch_distributed_backend +from .memory_format import convert_conv3d_weight_memory_format +from .chrono_inspector import ChronoInspector __all__ = [ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 'FlowDPMSolverMultistepScheduler', 
     'FlowUniPCMultistepScheduler',
-    'VaceVideoProcessor'
+    'VaceVideoProcessor', 'get_device', 'get_device_type', 'get_torch_distributed_backend',
+    'convert_conv3d_weight_memory_format', 'ChronoInspector'
 ]
diff --git a/wan/utils/chrono_inspector.py b/wan/utils/chrono_inspector.py
new file mode 100644
index 0000000..f3f096e
--- /dev/null
+++ b/wan/utils/chrono_inspector.py
@@ -0,0 +1,15 @@
+from time import perf_counter
+from logging import info
+
+
+class ChronoInspector(object):
+    def __init__(self, name: str = "Block"):
+        self.name = name
+
+    def __enter__(self):
+        self.start_time: float = perf_counter()
+        return self  # optional: return self so the caller can inspect the timer
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end_time: float = perf_counter()
+        info(f"[{self.name}] Elapsed time: {end_time - self.start_time:.2f} seconds")
diff --git a/wan/utils/memory_format.py b/wan/utils/memory_format.py
new file mode 100644
index 0000000..a4b19ba
--- /dev/null
+++ b/wan/utils/memory_format.py
@@ -0,0 +1,76 @@
+import torch
+
+
+def convert_conv3d_weight_memory_format(module: torch.nn.Module, memory_format: torch.memory_format):
+    r"""Convert the ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``.
+    The conversion recursively applies to nested ``nn.Module``, including ``module``.
+    Note that it only changes the memory_format, not the semantics of each dimension.
+    This function is used to facilitate the computation to adopt NHWC kernels, which
+    provides a considerable speed-up for fp16 data on CUDA devices with compute capability >= 7.0.
+
+    .. note::
+        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
+        than the utility function ``convert_conv3d_weight_memory_format``. Any
+        layer with a 4d weight will be affected by ``model.to``, which does not
+        necessarily benefit from conversion to the specified ``memory_format``.
+        One place we are confident in is the NDHWC (channels_last_3d) conversion for
+        convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
+        even in cases where we have to apply a permutation to input tensors.
+
+        Hence our strategy here is to convert only the weight of convolution to
+        channels_last_3d. This ensures that:
+        1. Fast convolution kernels will be used, the benefit of which could
+           outweigh the overhead of permutation (if the input is not in the same format).
+        2. No unnecessary permutations are applied on layers that do not benefit
+           from memory_format conversion.
+
+        The optimal case is that layers between convolution layers are channels-last
+        compatible. The input tensor would be permuted to channels last when it
+        encounters the first convolution layer and stay in that memory format,
+        so the following convolutions will not need to permute their input tensors.
+
+        In the case where a channels-last-incompatible layer sits between convolution
+        layers, we need to permute the input tensor back to contiguous format
+        for that layer. The input tensor will go through the remaining layers in
+        contiguous format and be permuted to channels last when it encounters
+        another convolution layer. There is no point in propagating that
+        permutation to an earlier layer, as most layers are quite agnostic to
+        ``memory_format``.
+
+        This claim might change when PyTorch supports fusion of permutations, as
+        there might then be a better spot to fuse the permutation other than
+        immediately before a convolution.
+
+    Args:
+        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
+            ``nn.Module``
+        memory_format: user specified ``memory_format``,
+            e.g.
``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv3d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") + >>> model = nn.Sequential( + >>> nn.Conv3d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> out = model(input) + """ + + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = ( + module.weight.detach().clone().contiguous(memory_format=memory_format) + ) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + return module diff --git a/wan/utils/platform.py b/wan/utils/platform.py new file mode 100644 index 0000000..8eec972 --- /dev/null +++ b/wan/utils/platform.py @@ -0,0 +1,61 @@ +from typing import Optional, List + +import torch + +try: + import torch_musa +except ModuleNotFoundError: + torch_musa = None + + +def _is_musa() -> bool: + if torch_musa is None: + return False + else: + return True + + +def get_device(local_rank: Optional[int] = None) -> torch.device: + if torch.cuda.is_available(): + return ( + torch.cuda.current_device() + if local_rank is None + else torch.device("cuda", local_rank) + ) + elif _is_musa(): + return ( + torch.musa.current_device() + if local_rank is None + else torch.device("musa", local_rank) + ) + else: + return torch.device("cpu") + + +def get_device_type() -> str: + if torch.cuda.is_available(): + return "cuda" + elif _is_musa(): + return "musa" + else: + return "cpu" + + +def get_torch_distributed_backend() -> str: + if torch.cuda.is_available(): + return "nccl" + elif _is_musa(): + return "mccl" + else: + raise NotImplementedError("No Accelerators(NV/MTT GPU) available") + + +def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]: + activities: List[torch.profiler.ProfilerActivity] = [ + torch.profiler.ProfilerActivity.CPU + ] + if torch.cuda.is_available(): + activities.append(torch.profiler.ProfilerActivity.CUDA) + elif _is_musa(): + activities.append(torch.profiler.ProfilerActivity.MUSA) + return activities diff --git a/wan/vace.py b/wan/vace.py index 8a4f744..98529b0 100644 --- a/wan/vace.py +++ b/wan/vace.py @@ -8,11 +8,13 @@ import sys import time import traceback import types +from time import perf_counter from contextlib import contextmanager from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F @@ -20,8 +22,17 @@ import torchvision.transforms.functional as TF from PIL import Image from tqdm import tqdm -from .modules.vace_model import VaceWanModel -from .text2video import ( +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize + torch.backends.mudnn.allow_tf32 = True +except ModuleNotFoundError: + torch_musa = None + +from wan.modules.vace_model import VaceWanModel +from wan.text2video 
import ( FlowDPMSolverMultistepScheduler, FlowUniPCMultistepScheduler, T5EncoderModel, @@ -31,7 +42,9 @@ from .text2video import ( retrieve_timesteps, shard_model, ) -from .utils.vace_processor import VaceVideoProcessor +from wan.utils.vace_processor import VaceVideoProcessor +from wan.utils.platform import get_device, get_torch_distributed_backend +from wan.utils.memory_format import convert_conv3d_weight_memory_format class WanVace(WanT2V): @@ -46,6 +59,7 @@ class WanVace(WanT2V): dit_fsdp=False, use_usp=False, t5_cpu=False, + profiler=None, ): r""" Initializes the Wan text-to-video generation model components. @@ -68,7 +82,7 @@ class WanVace(WanT2V): t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -90,6 +104,7 @@ class WanVace(WanT2V): self.vae = WanVAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), device=self.device) + convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d) logging.info(f"Creating VaceWanModel from {checkpoint_dir}") self.model = VaceWanModel.from_pretrained(checkpoint_dir) @@ -117,7 +132,8 @@ class WanVace(WanT2V): self.sp_size = 1 if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if dit_fsdp: self.model = shard_fn(self.model) else: @@ -136,6 +152,8 @@ class WanVace(WanT2V): seq_len=75600, keep_last=True) + self.profiler = profiler + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): vae = self.vae if vae is None else vae if ref_images is None: @@ -340,6 +358,11 @@ class WanVace(WanT2V): - H: Frame height (from size) - W: Frame width from size) """ + start_time = 0.0 + end_time = 0.0 + if self.rank == 0: + start_time = perf_counter() + # preprocess # F = frame_num # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, @@ -390,6 +413,10 @@ class WanVace(WanT2V): (self.patch_size[1] * self.patch_size[2]) * target_shape[1] / self.sp_size) * self.sp_size + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds") + @contextmanager def noop_no_sync(): yield @@ -426,13 +453,19 @@ class WanVace(WanT2V): arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + if self.rank == 0: + start_time = perf_counter() + + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): + if self.profiler and self.rank == 0: + self.profiler.step() + latent_model_input = latents timestep = [t] timestep = torch.stack(timestep) - self.model.to(self.device) noise_pred_cond = self.model( latent_model_input, t=timestep, @@ -457,20 +490,28 @@ class WanVace(WanT2V): generator=seed_g)[0] latents = [temp_x0.squeeze(0)] + if self.rank == 0: + end_time = perf_counter() + logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds") + x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.decode_latent(x0, input_ref_images) + end_time = perf_counter() + logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds") del noise, latents del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass return videos[0] if 
self.rank == 0 else None @@ -568,7 +609,7 @@ class WanVaceMP(WanVace): torch.cuda.set_device(gpu) dist.init_process_group( - backend='nccl', + backend=get_torch_distributed_backend(), init_method='env://', rank=rank, world_size=world_size) @@ -629,11 +670,11 @@ class WanVaceMP(WanVace): else: sp_size = 1 - dist.barrier() + # dist.barrier() model = shard_fn(model) sample_neg_prompt = self.config.sample_neg_prompt - torch.cuda.empty_cache() + empty_cache() event = initialized_events[gpu] in_q = in_q_list[gpu] event.set() @@ -748,7 +789,7 @@ class WanVaceMP(WanVace): generator=seed_g)[0] latents = [temp_x0.squeeze(0)] - torch.cuda.empty_cache() + empty_cache() x0 = latents if rank == 0: videos = self.decode_latent( @@ -758,9 +799,10 @@ class WanVaceMP(WanVace): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): - dist.barrier() + # dist.barrier() + pass if rank == 0: out_q.put(videos[0].cpu())
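
Usage sketch (not part of the patch above): a minimal example of how the helpers introduced in wan/utils/platform.py and wan/utils/chrono_inspector.py are meant to be used together with the WanT2V pipeline on either a CUDA or a MUSA machine. It assumes the package-level exports already used by generate.py (wan.WanT2V, WAN_CONFIGS, SIZE_CONFIGS); the task key, size key, checkpoint directory, prompt and sampling settings are placeholders.

    import logging

    import wan
    from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
    from wan.utils.chrono_inspector import ChronoInspector
    from wan.utils.platform import (
        get_device,
        get_device_type,
        get_torch_distributed_backend,
    )

    logging.basicConfig(level=logging.INFO)

    # Backend-agnostic device helpers: "cuda" on NVIDIA, "musa" when torch_musa
    # is importable, "cpu" otherwise; the matching collective backend is
    # "nccl" or "mccl".
    print(get_device_type(), get_device(0), get_torch_distributed_backend())

    cfg = WAN_CONFIGS["t2v-1.3B"]  # assumed task key, mirroring generate.py
    pipe = wan.WanT2V(
        config=cfg,
        checkpoint_dir="./Wan2.1-T2V-1.3B",  # placeholder checkpoint directory
        device_id=0,
        rank=0,
    )

    # ChronoInspector logs the wall-clock time of the wrapped block via logging.info,
    # the same mechanism the patch uses for its per-stage timings.
    with ChronoInspector("t2v generation"):
        video = pipe.generate(
            "A cat surfing a small wave at sunset",  # placeholder prompt
            size=SIZE_CONFIGS["480*832"],  # assumed size key
            frame_num=81,
            sampling_steps=30,
            seed=42,
        )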