diff --git a/generate.py b/generate.py index c841c19..3e293e6 100644 --- a/generate.py +++ b/generate.py @@ -4,6 +4,7 @@ import logging import os import sys import warnings +from time import perf_counter from datetime import datetime warnings.filterwarnings('ignore') @@ -12,12 +13,20 @@ import random import torch import torch.distributed as dist +from torch.cuda import set_device from PIL import Image +try: + import torch_musa + from torch_musa.core.device import set_device +except ModuleNotFoundError: + pass + import wan from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.utils import cache_image, cache_video, str2bool +from wan.utils.platform import get_torch_distributed_backend EXAMPLE_PROMPT = { @@ -275,9 +284,9 @@ def generate(args): logging.info( f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: - torch.cuda.set_device(local_rank) + set_device(local_rank) dist.init_process_group( - backend="nccl", + backend=get_torch_distributed_backend(), init_method="env://", rank=rank, world_size=world_size) @@ -357,6 +366,7 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanT2V pipeline.") + start_time = perf_counter() wan_t2v = wan.WanT2V( config=cfg, checkpoint_dir=args.ckpt_dir, @@ -367,6 +377,8 @@ def generate(args): use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, ) + end_time = perf_counter() + logging.info(f"Creating WanT2V pipeline took {end_time - start_time:.2f} seconds.") logging.info( f"Generating {'image' if 't2i' in args.task else 'video'} ...") @@ -380,7 +392,6 @@ def generate(args): guide_scale=args.sample_guide_scale, seed=args.base_seed, offload_model=args.offload_model) - elif "i2v" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] @@ -414,6 +425,7 @@ def generate(args): logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanI2V pipeline.") + start_time = perf_counter() wan_i2v = wan.WanI2V( config=cfg, checkpoint_dir=args.ckpt_dir, @@ -424,6 +436,8 @@ def generate(args): use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, ) + end_time = perf_counter() + logging.info(f"Creating WanI2V pipeline took {end_time - start_time:.2f} seconds.") logging.info("Generating video ...") video = wan_i2v.generate( @@ -572,6 +586,7 @@ def generate(args): value_range=(-1, 1)) else: logging.info(f"Saving generated video to {args.save_file}") + start_time = perf_counter() cache_video( tensor=video[None], save_file=args.save_file, @@ -579,6 +594,8 @@ def generate(args): nrow=1, normalize=True, value_range=(-1, 1)) + end_time = perf_counter() + logging.info(f"Saving Video took {end_time - start_time:.2f} seconds") logging.info("Finished.") diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 6bb496d..2b327ba 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -3,11 +3,17 @@ import gc from functools import partial import torch +from torch.cuda import empty_cache from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.utils import _free_storage +try: + import torch_musa + from torch_musa.core.memory import empty_cache +except ModuleNotFoundError: + pass def shard_model( model, @@ -40,4 +46,4 @@ def free_model(model): _free_storage(m._handle.flat_param.data) del model gc.collect() - torch.cuda.empty_cache() + empty_cache() diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index 4718577..9bc13fc 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -6,20 +6,28 @@ from xfuser.core.distributed import ( get_sequence_parallel_world_size, get_sp_group, ) -from xfuser.core.long_ctx_attention import xFuserLongContextAttention +from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType +attn_type:AttnType = AttnType.FA + +try: + import torch_musa + import torch_musa.core.amp as amp + attn_type = AttnType.TORCH +except ImportError: + pass from ..modules.model import sinusoidal_embedding_1d -def pad_freqs(original_tensor, target_len): +def pad_tensor(original_tensor, target_len, pad_value=0.0): seq_len, s1, s2 = original_tensor.shape pad_size = target_len - seq_len - padding_tensor = torch.ones( - pad_size, - s1, - s2, + padding_tensor = torch.full( + (pad_size, s1, s2), + pad_value, dtype=original_tensor.dtype, - device=original_tensor.device) + device=original_tensor.device, + ) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) return padded_tensor @@ -32,8 +40,13 @@ def rope_apply(x, grid_sizes, freqs): freqs: [M, C // 2]. """ s, n, c = x.size(1), x.size(2), x.size(3) // 2 + c0 = c - 2 * (c // 3) + c1 = c // 3 + c2 = c // 3 + # split freqs - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_real = freqs[0].split([c0, c1, c2], dim=1) + freqs_imag = freqs[-1].split([c0, c1, c2], dim=1) # loop over samples output = [] @@ -41,28 +54,45 @@ def rope_apply(x, grid_sizes, freqs): seq_len = f * h * w # precompute multipliers - x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( - s, n, -1, 2)) - freqs_i = torch.cat([ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], - dim=-1).reshape(seq_len, 1, -1) + x_i = x[i, :seq_len].reshape(s, n, -1, 2) + x_real = x_i[..., 0] + x_imag = x_i[..., 1] + freqs_real = torch.cat( + [ + freqs_real[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs_real[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs_real[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + freqs_imag = torch.cat( + [ + freqs_imag[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs_imag[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs_imag[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) # apply rotary embedding sp_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - freqs_i = pad_freqs(freqs_i, s * sp_size) - s_per_rank = s - freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * - s_per_rank), :, :] - x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) - x_i = torch.cat([x_i, x[i, s:]]) + + freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0) + freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0) + + freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :] + freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :] + + out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank + out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank + + x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2) + x_out = torch.cat([x_out, x[i, seq_len:]], dim=0) # append to collection - output.append(x_i) - return torch.stack(output).float() + output.append(x_out) + return torch.stack(output) def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs): @@ -109,9 +139,13 @@ def usp_dit_forward( if self.model_type == 'i2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) if self.model_type != 'vace' and y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -129,11 +163,9 @@ def usp_dit_forward( ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -177,7 +209,7 @@ def usp_dit_forward( # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def usp_attn_forward(self, @@ -210,7 +242,7 @@ def usp_attn_forward(self, # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) - x = xFuserLongContextAttention()( + x = xFuserLongContextAttention(attn_type=attn_type)( None, query=half(q), key=half(k), diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 232950f..58c9594 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -12,10 +12,19 @@ from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + from .distributed.fsdp import shard_model from .modules.clip import CLIPModel from .modules.model import WanModel @@ -27,6 +36,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .utils.platform import get_device class WanFLF2V: @@ -66,7 +76,7 @@ class WanFLF2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -323,7 +333,7 @@ class WanFLF2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): @@ -336,12 +346,12 @@ class WanFLF2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -361,7 +371,7 @@ class WanFLF2V: if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -370,7 +380,7 @@ class WanFLF2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..2167fda 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -6,16 +6,26 @@ import os import random import sys import types +from time import perf_counter from contextlib import contextmanager from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + from .distributed.fsdp import shard_model from .modules.clip import CLIPModel from .modules.model import WanModel @@ -27,6 +37,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .utils.platform import get_device class WanI2V: @@ -66,7 +77,7 @@ class WanI2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -220,6 +231,7 @@ class WanI2V: n_prompt = self.sample_neg_prompt # preprocess + start_time = perf_counter() if not self.t5_cpu: self.text_encoder.model.to(self.device) context = self.text_encoder([input_prompt], self.device) @@ -231,12 +243,18 @@ class WanI2V: context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] + end_time = perf_counter() + logging.info(f"T5 Encoding took {end_time - start_time:.2f} seconds.") + start_time = perf_counter() self.clip.model.to(self.device) clip_context = self.clip.visual([img[:, None, :, :]]) if offload_model: self.clip.model.cpu() + end_time = perf_counter() + logging.info(f"CLIP took {end_time - start_time:.2f} seconds.") + start_time = perf_counter() y = self.vae.encode([ torch.concat([ torch.nn.functional.interpolate( @@ -246,6 +264,9 @@ class WanI2V: ], dim=1).to(self.device) ])[0] + end_time = perf_counter() + logging.info(f"VAE Encoding took {end_time - start_time:.2f} seconds.") + y = torch.concat([msk, y]) @contextmanager @@ -296,8 +317,9 @@ class WanI2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() + start_time = perf_counter() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = [latent.to(self.device)] @@ -309,12 +331,12 @@ class WanI2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -331,19 +353,24 @@ class WanI2V: x0 = [latent.to(self.device)] del latent_model_input, timestep + end_time = perf_counter() + logging.info(f"Sampling took {end_time - start_time:.2f} seconds.") if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.vae.decode(x0) + end_time = perf_counter() + logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.") del noise, latent del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 4dbbe03..74fb83e 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -1,4 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import warnings + import torch try: @@ -13,7 +15,13 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False -import warnings +try: + import torch_musa + FLASH_ATTN_3_AVAILABLE = False + FLASH_ATTN_2_AVAILABLE = False +except ModuleNotFoundError: + pass + __all__ = [ 'flash_attention', @@ -51,7 +59,7 @@ def flash_attention( """ half_dtypes = (torch.float16, torch.bfloat16) assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 + assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256 # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype @@ -172,8 +180,9 @@ def attention( k = k.transpose(1, 2).to(dtype) v = v.transpose(1, 2).to(dtype) - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale) out = out.transpose(1, 2).contiguous() return out diff --git a/wan/modules/clip.py b/wan/modules/clip.py index 42dda04..6a4c513 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -6,12 +6,20 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import torch.cuda.amp as amp import torchvision.transforms as T from .attention import flash_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta +try: + import torch_musa + import torch_musa.core.amp as amp + from .attention import attention as flash_attention +except ModuleNotFoundError: + pass + __all__ = [ 'XLMRobertaCLIP', 'clip_xlm_roberta_vit_h_14', @@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len): return torch.cat([ pos[:, :n], F.interpolate( - pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + pos[:, n:].reshape(1, src_grid, src_grid, -1).permute( 0, 3, 1, 2), size=(tar_grid, tar_grid), mode='bicubic', @@ -44,12 +52,6 @@ class QuickGELU(nn.Module): return x * torch.sigmoid(1.702 * x) -class LayerNorm(nn.LayerNorm): - - def forward(self, x): - return super().forward(x.float()).type_as(x) - - class SelfAttention(nn.Module): def __init__(self, @@ -82,7 +84,7 @@ class SelfAttention(nn.Module): # compute attention p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal) x = x.reshape(b, s, c) # output @@ -131,10 +133,10 @@ class AttentionBlock(nn.Module): self.norm_eps = norm_eps # layers - self.norm1 = LayerNorm(dim, eps=norm_eps) + self.norm1 = nn.LayerNorm(dim, eps=norm_eps) self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) - self.norm2 = LayerNorm(dim, eps=norm_eps) + self.norm2 = nn.LayerNorm(dim, eps=norm_eps) if activation == 'swi_glu': self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) else: @@ -177,7 +179,7 @@ class AttentionPool(nn.Module): self.to_q = nn.Linear(dim, dim) self.to_kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) - self.norm = LayerNorm(dim, eps=norm_eps) + self.norm = nn.LayerNorm(dim, eps=norm_eps) self.mlp = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == 'quick_gelu' else nn.GELU(), @@ -259,13 +261,13 @@ class VisionTransformer(nn.Module): self.dropout = nn.Dropout(embedding_dropout) # transformer - self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None self.transformer = nn.Sequential(*[ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers) ]) - self.post_norm = LayerNorm(dim, eps=norm_eps) + self.post_norm = nn.LayerNorm(dim, eps=norm_eps) # head if pool_type == 'token': @@ -537,6 +539,6 @@ class CLIPModel: videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) # forward - with torch.cuda.amp.autocast(dtype=self.dtype): + with amp.autocast(dtype=self.dtype): out = self.model.visual(videos, use_31_block=True) return out diff --git a/wan/modules/model.py b/wan/modules/model.py index a5425da..5d0b098 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -7,7 +7,15 @@ import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin -from .attention import flash_attention +from wan.utils.platform import get_device +from wan.modules.attention import flash_attention + +try: + import torch_musa + import torch_musa.core.amp as amp + from wan.modules.attention import attention as flash_attention +except ModuleNotFoundError: + pass __all__ = ['WanModel'] @@ -19,7 +27,7 @@ def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 - position = position.type(torch.float64) + position = position.type(torch.bfloat16) # calculation sinusoid = torch.outer( @@ -29,22 +37,45 @@ def sinusoidal_embedding_1d(dim, position): @amp.autocast(enabled=False) -def rope_params(max_seq_len, dim, theta=10000): +def rope_params_real( + max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu") +): assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), - 1.0 / torch.pow(theta, - torch.arange(0, dim, 2).to(torch.float64).div(dim))) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs + freqs_real = torch.outer( + torch.arange(max_seq_len, dtype=dtype, device=device), + 1.0 + / torch.pow( + theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim) + ), + ) + return torch.cos(freqs_real) + + +@amp.autocast(enabled=False) +def rope_params_imag( + max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu") +): + assert dim % 2 == 0 + freqs_imag = torch.outer( + torch.arange(max_seq_len, dtype=dtype, device=device), + 1.0 + / torch.pow( + theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim) + ), + ) + return torch.sin(freqs_imag) @amp.autocast(enabled=False) def rope_apply(x, grid_sizes, freqs): n, c = x.size(2), x.size(3) // 2 + c0 = c - 2 * (c // 3) + c1 = c // 3 + c2 = c // 3 # split freqs - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_real = freqs[0].split([c0, c1, c2], dim=1) + freqs_imag = freqs[-1].split([c0, c1, c2], dim=1) # loop over samples output = [] @@ -52,22 +83,36 @@ def rope_apply(x, grid_sizes, freqs): seq_len = f * h * w # precompute multipliers - x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( - seq_len, n, -1, 2)) - freqs_i = torch.cat([ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], - dim=-1).reshape(seq_len, 1, -1) + x_i = x[i, :seq_len].reshape(seq_len, n, c, 2) + x_real = x_i[..., 0] + x_imag = x_i[..., 1] + freqs_real = torch.cat( + [ + freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0), + freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1), + freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2), + ], + dim=-1, + ).reshape(seq_len, 1, c) + freqs_imag = torch.cat( + [ + freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0), + freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1), + freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2), + ], + dim=-1, + ).reshape(seq_len, 1, c) + + out_real = x_real * freqs_real - x_imag * freqs_imag + out_imag = x_real * freqs_imag + x_imag * freqs_real # apply rotary embedding - x_i = torch.view_as_real(x_i * freqs_i).flatten(2) - x_i = torch.cat([x_i, x[i, seq_len:]]) + x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2) + x_out = torch.cat([x_out, x[i, seq_len:]], dim=0) # append to collection - output.append(x_i) - return torch.stack(output).float() + output.append(x_out) + return torch.stack(output) class WanRMSNorm(nn.Module): @@ -89,19 +134,6 @@ class WanRMSNorm(nn.Module): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) -class WanLayerNorm(nn.LayerNorm): - - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return super().forward(x.float()).type_as(x) - - class WanSelfAttention(nn.Module): def __init__(self, @@ -256,10 +288,10 @@ class WanAttentionBlock(nn.Module): self.eps = eps # layers - self.norm1 = WanLayerNorm(dim, eps) + self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) - self.norm3 = WanLayerNorm( + self.norm3 = nn.LayerNorm( dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, @@ -267,7 +299,7 @@ class WanAttentionBlock(nn.Module): (-1, -1), qk_norm, eps) - self.norm2 = WanLayerNorm(dim, eps) + self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim)) @@ -293,24 +325,19 @@ class WanAttentionBlock(nn.Module): grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 + e = (self.modulation + e).chunk(6, dim=1) # self-attention y = self.self_attn( - self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, + self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) - with amp.autocast(dtype=torch.float32): - x = x + y * e[2] + x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) - with amp.autocast(dtype=torch.float32): - x = x + y * e[5] + y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) + x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) @@ -328,7 +355,7 @@ class Head(nn.Module): # layers out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) + self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False) self.head = nn.Linear(dim, out_dim) # modulation @@ -340,10 +367,8 @@ class Head(nn.Module): x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) return x @@ -477,12 +502,23 @@ class WanModel(ModelMixin, ConfigMixin): # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) + freqs_real = torch.cat( + [ + rope_params_real(1024, d - 4 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + ], + dim=1, + ) + freqs_imag = torch.cat( + [ + rope_params_imag(1024, d - 4 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + ], + dim=1, + ) + self.freqs = (freqs_real, freqs_imag) if model_type == 'i2v' or model_type == 'flf2v': self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') @@ -523,9 +559,13 @@ class WanModel(ModelMixin, ConfigMixin): if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -543,11 +583,9 @@ class WanModel(ModelMixin, ConfigMixin): ]) # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context_lens = None @@ -579,7 +617,7 @@ class WanModel(ModelMixin, ConfigMixin): # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + return x def unpatchify(self, x, grid_sizes): r""" diff --git a/wan/modules/t5.py b/wan/modules/t5.py index c841b04..d19cd01 100644 --- a/wan/modules/t5.py +++ b/wan/modules/t5.py @@ -7,7 +7,14 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .tokenizers import HuggingfaceTokenizer +try: + import torch_musa + from torch_musa.core.device import current_device +except ModuleNotFoundError: + pass + +from wan.modules.tokenizers import HuggingfaceTokenizer +from wan.utils.platform import get_device __all__ = [ 'T5Model', @@ -59,10 +66,8 @@ class T5LayerNorm(nn.Module): self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): - x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.type_as(self.weight) return self.weight * x @@ -110,7 +115,7 @@ class T5Attention(nn.Module): # compute attention (T5 does not use scaling) attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias - attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = F.softmax(attn, dim=-1) x = torch.einsum('bnij,bjnc->binc', attn, v) # output @@ -255,7 +260,7 @@ class T5RelativeEmbedding(nn.Module): # embeddings for small and large positions max_exact = num_buckets // 2 - rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long() rel_pos_large = torch.min( @@ -475,7 +480,7 @@ class T5EncoderModel: self, text_len, dtype=torch.bfloat16, - device=torch.cuda.current_device(), + device=get_device(), checkpoint_path=None, tokenizer_path=None, shard_fn=None, diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py index a12d1dd..440fb60 100644 --- a/wan/modules/vace_model.py +++ b/wan/modules/vace_model.py @@ -4,6 +4,12 @@ import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import register_to_config +try: + import torch_musa + import torch_musa.core.amp as amp +except ModuleNotFoundError: + pass + from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 5c6da57..b4e30cc 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -1,12 +1,20 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging +from math import sqrt import torch -import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F +from torch.nn import Upsample from einops import rearrange +try: + import torch_musa +except ModuleNotFoundError: + pass + +from wan.utils.platform import get_device + __all__ = [ 'WanVAE', ] @@ -44,23 +52,17 @@ class RMS_norm(nn.Module): shape = (dim, *broadcastable_dims) if channel_first else (dim,) self.channel_first = channel_first - self.scale = dim**0.5 + self.scale = sqrt(dim) self.gamma = nn.Parameter(torch.ones(shape)) self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) + return ( + F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x) + * self.scale + * self.gamma + + self.bias + ) class Resample(nn.Module): @@ -253,6 +255,10 @@ class AttentionBlock(nn.Module): q, k, v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, ) x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) @@ -621,8 +627,8 @@ class WanVAE: def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cuda"): + dtype=torch.bfloat16, + device=get_device()): self.dtype = dtype self.device = device @@ -648,16 +654,12 @@ class WanVAE: """ videos: A list of videos each with shape [C, T, H, W]. """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] + return [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos + ] def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + return [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..25ffff9 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -6,14 +6,24 @@ import os import random import sys import types +from time import perf_counter from contextlib import contextmanager from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel @@ -24,7 +34,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - +from .utils.platform import get_device class WanT2V: @@ -60,7 +70,7 @@ class WanT2V: t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -171,6 +181,7 @@ class WanT2V: seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) + start_time = perf_counter() if not self.t5_cpu: self.text_encoder.model.to(self.device) context = self.text_encoder([input_prompt], self.device) @@ -182,6 +193,8 @@ class WanT2V: context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] + end_time = perf_counter() + logging.info(f"T5 Encoding Context took {end_time - start_time:.2f} seconds.") noise = [ torch.randn( @@ -230,13 +243,14 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} + start_time = perf_counter() + self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents timestep = [t] timestep = torch.stack(timestep) - self.model.to(self.device) noise_pred_cond = self.model( latent_model_input, t=timestep, **arg_c)[0] noise_pred_uncond = self.model( @@ -252,19 +266,24 @@ class WanT2V: return_dict=False, generator=seed_g)[0] latents = [temp_x0.squeeze(0)] + end_time = perf_counter() + logging.info(f"Sampling took {end_time - start_time:.2f} seconds.") x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: + start_time = perf_counter() videos = self.vae.decode(x0) + end_time = perf_counter() + logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.") del noise, latents del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index 2e9b33d..02d6687 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -5,9 +5,10 @@ from .fm_solvers import ( ) from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .vace_processor import VaceVideoProcessor +from .platform import get_device, get_torch_distributed_backend __all__ = [ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', - 'VaceVideoProcessor' + 'VaceVideoProcessor', 'get_device', 'get_torch_distributed_backend' ] diff --git a/wan/utils/platform.py b/wan/utils/platform.py new file mode 100644 index 0000000..8731a49 --- /dev/null +++ b/wan/utils/platform.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch + +try: + import torch_musa +except ModuleNotFoundError: + pass + + +def _is_musa(): + try: + if torch.musa.is_available(): + return True + except ModuleNotFoundError: + return False + + +def get_device(local_rank:Optional[int]=None) -> torch.device: + if torch.cuda.is_available(): + return torch.cuda.current_device() if local_rank is None else torch.device("cuda", local_rank) + elif _is_musa(): + return torch.musa.current_device() if local_rank is None else torch.device("musa", local_rank) + else: + return torch.device("cpu") + + +def get_torch_distributed_backend() -> str: + if torch.cuda.is_available(): + return "nccl" + elif _is_musa(): + return "mccl" + else: + raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available") diff --git a/wan/vace.py b/wan/vace.py index 8a4f744..06f9b51 100644 --- a/wan/vace.py +++ b/wan/vace.py @@ -13,6 +13,7 @@ from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F @@ -20,6 +21,14 @@ import torchvision.transforms.functional as TF from PIL import Image from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + from .modules.vace_model import VaceWanModel from .text2video import ( FlowDPMSolverMultistepScheduler, @@ -32,6 +41,7 @@ from .text2video import ( shard_model, ) from .utils.vace_processor import VaceVideoProcessor +from .utils.platform import get_device, get_torch_distributed_backend class WanVace(WanT2V): @@ -68,7 +78,7 @@ class WanVace(WanT2V): t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -460,7 +470,7 @@ class WanVace(WanT2V): x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.decode_latent(x0, input_ref_images) @@ -468,7 +478,7 @@ class WanVace(WanT2V): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() @@ -568,7 +578,7 @@ class WanVaceMP(WanVace): torch.cuda.set_device(gpu) dist.init_process_group( - backend='nccl', + backend=get_torch_distributed_backend(), init_method='env://', rank=rank, world_size=world_size) @@ -633,7 +643,7 @@ class WanVaceMP(WanVace): model = shard_fn(model) sample_neg_prompt = self.config.sample_neg_prompt - torch.cuda.empty_cache() + empty_cache() event = initialized_events[gpu] in_q = in_q_list[gpu] event.set() @@ -748,7 +758,7 @@ class WanVaceMP(WanVace): generator=seed_g)[0] latents = [temp_x0.squeeze(0)] - torch.cuda.empty_cache() + empty_cache() x0 = latents if rank == 0: videos = self.decode_latent( @@ -758,7 +768,7 @@ class WanVaceMP(WanVace): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier()