From 447aa08620cb54b51071f1eecf5e6e90a5121756 Mon Sep 17 00:00:00 2001 From: Houchen Li Date: Mon, 28 Jul 2025 19:07:56 +0800 Subject: [PATCH] [feature] adapt for Moore Threads GPU family --- generate.py | 12 +- wan/distributed/fsdp.py | 8 +- wan/distributed/xdit_context_parallel.py | 109 +++++++++++++++- wan/first_last_frame2video.py | 22 +++- wan/image2video.py | 22 +++- wan/modules/attention.py | 14 +- wan/modules/clip.py | 20 ++- wan/modules/model.py | 157 ++++++++++++++++++++--- wan/modules/t5.py | 9 +- wan/modules/vace_model.py | 6 + wan/modules/vae.py | 11 +- wan/text2video.py | 17 ++- wan/utils/__init__.py | 3 +- wan/utils/platform.py | 34 +++++ wan/vace.py | 24 +++- 15 files changed, 411 insertions(+), 57 deletions(-) create mode 100644 wan/utils/platform.py diff --git a/generate.py b/generate.py index c841c19..ca9b2cc 100644 --- a/generate.py +++ b/generate.py @@ -12,12 +12,20 @@ import random import torch import torch.distributed as dist +from torch.cuda import set_device from PIL import Image +try: + import torch_musa + from torch_musa.core.device import set_device +except ModuleNotFoundError: + torch_musa = None + import wan from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.utils import cache_image, cache_video, str2bool +from wan.utils.platform import get_torch_distributed_backend EXAMPLE_PROMPT = { @@ -275,9 +283,9 @@ def generate(args): logging.info( f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: - torch.cuda.set_device(local_rank) + set_device(local_rank) dist.init_process_group( - backend="nccl", + backend=get_torch_distributed_backend(), init_method="env://", rank=rank, world_size=world_size) diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 6bb496d..ca12048 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -3,11 +3,17 @@ import gc from 
def pad_tensor(original_tensor, target_len, pad_value=0.0):
    """Pad a 3-D tensor along dim 0 up to ``target_len`` rows.

    Args:
        original_tensor: Tensor of shape ``[seq_len, s1, s2]``.
        target_len: Number of rows the result must have; expected to be
            ``>= seq_len`` (a smaller value makes ``torch.full`` raise).
        pad_value: Fill value for the appended rows.

    Returns:
        Tensor of shape ``[target_len, s1, s2]`` with the same dtype and
        device as ``original_tensor``.
    """
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.full(
        (pad_size, s1, s2),
        pad_value,
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def _grid_table(axis_tables, f, h, w):
    """Broadcast per-axis tables over an ``f x h x w`` grid -> ``[f*h*w, 1, c]``."""
    table_f, table_h, table_w = axis_tables
    return torch.cat(
        [
            table_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            table_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            table_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(f * h * w, 1, -1)


@amp.autocast(enabled=False)
def rope_apply_musa(x, grid_sizes, freqs):
    """Apply rotary position embedding with real arithmetic on a sequence shard.

    x:          [B, L, N, C] local shard of the sequence.
    grid_sizes: [B, 3] rows of (f, h, w).
    freqs:      pair ``(cos_table, sin_table)``, each [M, C // 2].

    The cos/sin tables are built for the full ``f*h*w`` sequence, padded to
    ``L * sp_size`` rows (cos with 1.0, sin with 0.0, i.e. a zero rotation)
    and then sliced to the rows owned by the current sequence-parallel rank.

    Returns a tensor with the same shape as ``x``.
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    c0 = c - 2 * (c // 3)
    c1 = c // 3
    c2 = c // 3

    # Split the tables per axis once.  These tuples must NOT be rebound inside
    # the loop: doing so made every sample after the first index into the
    # previous sample's concatenated table and fail for batch size > 1.
    cos_axes = freqs[0].split([c0, c1, c2], dim=1)
    sin_axes = freqs[-1].split([c0, c1, c2], dim=1)

    # Loop-invariant: which slice of the padded sequence this rank owns.
    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    row_start, row_stop = sp_rank * s, (sp_rank + 1) * s

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # interleaved (real, imag) pairs of the local shard
        x_i = x[i, :seq_len].reshape(s, n, -1, 2)
        x_real = x_i[..., 0]
        x_imag = x_i[..., 1]

        # full-sequence tables, padded and cut down to this rank's rows
        cos_i = pad_tensor(_grid_table(cos_axes, f, h, w), s * sp_size, 1.0)
        sin_i = pad_tensor(_grid_table(sin_axes, f, h, w), s * sp_size, 0.0)
        cos_rank = cos_i[row_start:row_stop, :, :]
        sin_rank = sin_i[row_start:row_stop, :, :]

        # complex multiplication (a + bi)(cos + i sin) written out in reals
        out_real = x_real * cos_rank - x_imag * sin_rank
        out_imag = x_real * sin_rank + x_imag * cos_rank

        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)

        # append to collection
        output.append(x_out)
    return torch.stack(output)
and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if torch_musa is not None: + if self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) + else: + if self.freqs.dtype != dtype or self.freqs.device != device: + self.freqs = self.freqs.to(dtype=dtype, device=device) if self.model_type != 'vace' and y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] @@ -200,8 +292,13 @@ def usp_attn_forward(self, return q, k, v q, k, v = qkv_fn(x) - q = rope_apply(q, grid_sizes, freqs) - k = rope_apply(k, grid_sizes, freqs) + + if torch_musa is not None: + q = rope_apply_musa(q, grid_sizes, freqs) + k = rope_apply_musa(k, grid_sizes, freqs) + else: + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) # TODO: We should use unpaded q,k,v for attention. 
# k_lens = seq_lens // get_sequence_parallel_world_size() @@ -210,7 +307,7 @@ def usp_attn_forward(self, # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) - x = xFuserLongContextAttention()( + x = xFuserLongContextAttention(attn_type=attn_type)( None, query=half(q), key=half(k), diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 232950f..85bb2c7 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -12,10 +12,19 @@ from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + torch_musa = None + from .distributed.fsdp import shard_model from .modules.clip import CLIPModel from .modules.model import WanModel @@ -27,6 +36,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .utils.platform import get_device class WanFLF2V: @@ -66,7 +76,7 @@ class WanFLF2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -323,7 +333,7 @@ class WanFLF2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): @@ -336,12 +346,12 @@ class WanFLF2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -361,7 +371,7 @@ class WanFLF2V: if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -370,7 +380,7 @@ class WanFLF2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() diff --git a/wan/image2video.py b/wan/image2video.py index 6882c53..04f4cd0 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -12,10 +12,19 @@ from functools import partial import numpy as np import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + torch_musa = None + from .distributed.fsdp import shard_model from .modules.clip import CLIPModel from .modules.model import WanModel @@ -27,6 +36,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 
+from .utils.platform import get_device class WanI2V: @@ -66,7 +76,7 @@ class WanI2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. """ - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.use_usp = use_usp @@ -296,7 +306,7 @@ class WanI2V: } if offload_model: - torch.cuda.empty_cache() + empty_cache() self.model.to(self.device) for _, t in enumerate(tqdm(timesteps)): @@ -309,12 +319,12 @@ class WanI2V: latent_model_input, t=timestep, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred_uncond = self.model( latent_model_input, t=timestep, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: - torch.cuda.empty_cache() + empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) @@ -334,7 +344,7 @@ class WanI2V: if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.vae.decode(x0) @@ -343,7 +353,7 @@ class WanI2V: del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 4dbbe03..c8ac269 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -1,4 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import warnings + import torch try: @@ -13,7 +15,13 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False -import warnings +try: + import torch_musa + FLASH_ATTN_3_AVAILABLE = False + FLASH_ATTN_2_AVAILABLE = False +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'flash_attention', @@ -51,7 +59,7 @@ def flash_attention( """ half_dtypes = (torch.float16, torch.bfloat16) assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 + assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256 # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype @@ -173,7 +181,7 @@ def attention( v = v.transpose(1, 2).to(dtype) out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale) out = out.transpose(1, 2).contiguous() return out diff --git a/wan/modules/clip.py b/wan/modules/clip.py index 42dda04..f74d364 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -6,12 +6,20 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import torch.cuda.amp as amp import torchvision.transforms as T from .attention import flash_attention from .tokenizers import HuggingfaceTokenizer from .xlm_roberta import XLMRoberta +try: + import torch_musa + import torch_musa.core.amp as amp + from .attention import attention as flash_attention +except ModuleNotFoundError: + torch_musa = None + __all__ = [ 'XLMRobertaCLIP', 'clip_xlm_roberta_vit_h_14', @@ -82,7 +90,10 @@ class SelfAttention(nn.Module): # compute attention p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + if torch_musa is not None: + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal) + else: + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) x = 
def _rope_angles(max_seq_len, dim, theta, dtype, device):
    """Return the ``[max_seq_len, dim // 2]`` table ``pos / theta**(2k / dim)``."""
    assert dim % 2 == 0
    return torch.outer(
        torch.arange(max_seq_len, dtype=dtype, device=device),
        1.0
        / torch.pow(
            theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
        ),
    )


@amp.autocast(enabled=False)
def rope_params_real(
    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
    """Cosine (real) part of the rotary table, shape ``[max_seq_len, dim // 2]``."""
    return torch.cos(_rope_angles(max_seq_len, dim, theta, dtype, device))


@amp.autocast(enabled=False)
def rope_params_imag(
    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
    """Sine (imaginary) part of the rotary table, shape ``[max_seq_len, dim // 2]``."""
    return torch.sin(_rope_angles(max_seq_len, dim, theta, dtype, device))


def _grid_table(axis_tables, f, h, w):
    """Broadcast per-axis tables over an ``f x h x w`` grid -> ``[f*h*w, 1, c]``."""
    table_f, table_h, table_w = axis_tables
    return torch.cat(
        [
            table_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            table_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            table_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(f * h * w, 1, -1)


@amp.autocast(enabled=False)
def rope_apply_musa(x, grid_sizes, freqs):
    """Apply rotary position embedding using real (cos/sin) arithmetic.

    x:          [B, L, N, C] with C holding interleaved (real, imag) pairs.
    grid_sizes: [B, 3] rows of (f, h, w); the first ``f*h*w`` tokens of each
                sample are rotated, the remaining tokens are passed through.
    freqs:      pair ``(cos_table, sin_table)``, each [M, C // 2].

    Returns a tensor with the same shape and dtype as ``x``.
    """
    n, c = x.size(2), x.size(3) // 2
    c0 = c - 2 * (c // 3)
    c1 = c // 3
    c2 = c // 3

    # Split the tables per axis once.  These tuples must NOT be rebound inside
    # the loop: doing so made every sample after the first index into the
    # previous sample's concatenated table and fail for batch size > 1.
    cos_axes = freqs[0].split([c0, c1, c2], dim=1)
    sin_axes = freqs[-1].split([c0, c1, c2], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # interleaved (real, imag) pairs of the valid tokens
        x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
        x_real = x_i[..., 0]
        x_imag = x_i[..., 1]

        # per-sample multipliers, [seq_len, 1, c]
        cos_i = _grid_table(cos_axes, f, h, w)
        sin_i = _grid_table(sin_axes, f, h, w)

        # complex multiplication (a + bi)(cos + i sin) written out in reals
        out_real = x_real * cos_i - x_imag * sin_i
        out_imag = x_real * sin_i + x_imag * cos_i

        # re-interleave and re-attach the untouched tail tokens
        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)

        # append to collection
        output.append(x_out)
    return torch.stack(output)
flash_attention( - q=rope_apply(q, grid_sizes, freqs), - k=rope_apply(k, grid_sizes, freqs), - v=v, - k_lens=seq_lens, - window_size=self.window_size) + if torch_musa is not None: + x = flash_attention( + q=rope_apply_musa(q, grid_sizes, freqs), + k=rope_apply_musa(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size, + ) + else: + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size, + ) # output x = x.flatten(2) @@ -477,12 +573,33 @@ class WanModel(ModelMixin, ConfigMixin): # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - rope_params(1024, 2 * (d // 6)) - ], - dim=1) + if torch_musa is not None: + freqs_real = torch.cat( + [ + rope_params_real(1024, d - 4 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + rope_params_real(1024, 2 * (d // 6)), + ], + dim=1, + ) + freqs_imag = torch.cat( + [ + rope_params_imag(1024, d - 4 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + rope_params_imag(1024, 2 * (d // 6)), + ], + dim=1, + ) + self.freqs = (freqs_real, freqs_imag) + else: + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) if model_type == 'i2v' or model_type == 'flf2v': self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') @@ -523,9 +640,17 @@ class WanModel(ModelMixin, ConfigMixin): if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params + dtype = self.patch_embedding.weight.dtype device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + if torch_musa is not None: + if 
self.freqs[0].dtype != dtype or self.freqs[0].device != device: + self.freqs = ( + self.freqs[0].to(dtype=dtype, device=device), + self.freqs[-1].to(dtype=dtype, device=device) + ) + else: + if self.freqs.dtype != dtype or self.freqs.device != device: + self.freqs = self.freqs.to(dtype=dtype, device=device) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] diff --git a/wan/modules/t5.py b/wan/modules/t5.py index c841b04..6f8ef44 100644 --- a/wan/modules/t5.py +++ b/wan/modules/t5.py @@ -6,6 +6,13 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +from torch.cuda import current_device + +try: + import torch_musa + from torch_musa.core.device import current_device +except ModuleNotFoundError: + torch_musa = None from .tokenizers import HuggingfaceTokenizer @@ -475,7 +482,7 @@ class T5EncoderModel: self, text_len, dtype=torch.bfloat16, - device=torch.cuda.current_device(), + device=current_device(), checkpoint_path=None, tokenizer_path=None, shard_fn=None, diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py index a12d1dd..8913c2e 100644 --- a/wan/modules/vace_model.py +++ b/wan/modules/vace_model.py @@ -4,6 +4,12 @@ import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import register_to_config +try: + import torch_musa + import torch_musa.core.amp as amp +except ModuleNotFoundError: + torch_musa = None + from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 5c6da57..a6b2576 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -5,8 +5,17 @@ import torch import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F +from torch.nn import Upsample from einops import rearrange +try: + import torch_musa + import torch_musa.core.amp as amp +except ModuleNotFoundError: + torch_musa = None + +from wan.utils.platform import get_device + __all__ = [ 'WanVAE', ] @@ -622,7 
+631,7 @@ class WanVAE: z_dim=16, vae_pth='cache/vae_step_411000.pth', dtype=torch.float, - device="cuda"): + device=get_device()): self.dtype = dtype self.device = device diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..e1f79a0 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -11,9 +11,18 @@ from functools import partial import torch import torch.cuda.amp as amp +from torch.cuda import empty_cache, synchronize import torch.distributed as dist from tqdm import tqdm +try: + import torch_musa + import torch_musa.core.amp as amp + from torch_musa.core.memory import empty_cache + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + torch_musa = None + from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel @@ -24,7 +33,7 @@ from .utils.fm_solvers import ( retrieve_timesteps, ) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - +from .utils.platform import get_device class WanT2V: @@ -60,7 +69,7 @@ class WanT2V: t5_cpu (`bool`, *optional*, defaults to False): Whether to place T5 model on CPU. Only works without t5_fsdp. 
from typing import Optional

import torch

try:
    import torch_musa
except ModuleNotFoundError:
    torch_musa = None


def _is_musa() -> bool:
    """Return True when a ``torch.musa`` backend exists and reports a device.

    When ``torch`` has no ``musa`` attribute, attribute access raises
    ``AttributeError`` rather than ``ModuleNotFoundError``, so the backend is
    looked up with ``getattr`` instead of a try/except.  Always returns a
    real ``bool`` (never ``None``).
    """
    musa = getattr(torch, "musa", None)
    if musa is None:
        return False
    return bool(musa.is_available())


def get_device(local_rank: Optional[int] = None) -> torch.device:
    """Pick the accelerator device for this process.

    Args:
        local_rank: Device index to use.  When ``None`` the backend's current
            device index is used.

    Returns:
        A ``torch.device`` for CUDA if available, otherwise MUSA if available,
        otherwise the CPU device.  The result is always a ``torch.device``
        (``current_device()`` returns a bare int, which is wrapped here).
    """
    if torch.cuda.is_available():
        index = torch.cuda.current_device() if local_rank is None else local_rank
        return torch.device("cuda", index)
    if _is_musa():
        index = torch.musa.current_device() if local_rank is None else local_rank
        return torch.device("musa", index)
    return torch.device("cpu")


def get_torch_distributed_backend() -> str:
    """Return the ``torch.distributed`` backend name for the available accelerator.

    Returns:
        ``"nccl"`` when CUDA is available, ``"mccl"`` when MUSA is available.

    Raises:
        NotImplementedError: If neither accelerator is available.
    """
    if torch.cuda.is_available():
        return "nccl"
    if _is_musa():
        return "mccl"
    raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
""" - self.device = torch.device(f"cuda:{device_id}") + self.device = get_device(device_id) self.config = config self.rank = rank self.t5_cpu = t5_cpu @@ -460,7 +470,7 @@ class WanVace(WanT2V): x0 = latents if offload_model: self.model.cpu() - torch.cuda.empty_cache() + empty_cache() if self.rank == 0: videos = self.decode_latent(x0, input_ref_images) @@ -468,7 +478,7 @@ class WanVace(WanT2V): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier() @@ -568,7 +578,7 @@ class WanVaceMP(WanVace): torch.cuda.set_device(gpu) dist.init_process_group( - backend='nccl', + backend=get_torch_distributed_backend(), init_method='env://', rank=rank, world_size=world_size) @@ -633,7 +643,7 @@ class WanVaceMP(WanVace): model = shard_fn(model) sample_neg_prompt = self.config.sample_neg_prompt - torch.cuda.empty_cache() + empty_cache() event = initialized_events[gpu] in_q = in_q_list[gpu] event.set() @@ -748,7 +758,7 @@ class WanVaceMP(WanVace): generator=seed_g)[0] latents = [temp_x0.squeeze(0)] - torch.cuda.empty_cache() + empty_cache() x0 = latents if rank == 0: videos = self.decode_latent( @@ -758,7 +768,7 @@ class WanVaceMP(WanVace): del sample_scheduler if offload_model: gc.collect() - torch.cuda.synchronize() + synchronize() if dist.is_initialized(): dist.barrier()