Compare commits


2 Commits

Author       SHA1         Message                                                        Date
Houchen Li   618d94c564   Merge 029e421891 into 7c81b2f27d                               2025-07-28 17:23:28 +08:00
Houchen Li   029e421891   [feature] adapt for Moore Threads graphics processing unit    2025-07-28 17:10:46 +08:00
15 changed files with 399 additions and 181 deletions

View File

@ -4,6 +4,7 @@ import logging
import os import os
import sys import sys
import warnings import warnings
from time import perf_counter
from datetime import datetime from datetime import datetime
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@ -12,12 +13,20 @@ import random
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.cuda import set_device
from PIL import Image from PIL import Image
try:
import torch_musa
from torch_musa.core.device import set_device
except ModuleNotFoundError:
pass
import wan import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool from wan.utils.utils import cache_image, cache_video, str2bool
from wan.utils.platform import get_torch_distributed_backend
EXAMPLE_PROMPT = { EXAMPLE_PROMPT = {
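
The guarded torch_musa import above is the pattern the whole change relies on: the CUDA helpers are imported first, and when torch_musa is importable its equivalents rebind the same names, leaving the call sites backend-agnostic. A condensed standalone sketch of that idea (names taken from this diff, nothing new):

```python
from torch.cuda import empty_cache, set_device  # CUDA defaults

try:
    # On Moore Threads machines torch_musa is importable and its helpers
    # rebind the same names; on CUDA-only machines this block is skipped.
    import torch_musa
    from torch_musa.core.device import set_device
    from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
    pass

# set_device(local_rank) and empty_cache() now dispatch to whichever backend is present.
```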
@ -275,9 +284,9 @@ def generate(args):
logging.info( logging.info(
f"offload_model is not specified, set to {args.offload_model}.") f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1: if world_size > 1:
torch.cuda.set_device(local_rank) set_device(local_rank)
dist.init_process_group( dist.init_process_group(
backend="nccl", backend=get_torch_distributed_backend(),
init_method="env://", init_method="env://",
rank=rank, rank=rank,
world_size=world_size) world_size=world_size)
@ -357,6 +366,7 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.") logging.info("Creating WanT2V pipeline.")
start_time = perf_counter()
wan_t2v = wan.WanT2V( wan_t2v = wan.WanT2V(
config=cfg, config=cfg,
checkpoint_dir=args.ckpt_dir, checkpoint_dir=args.ckpt_dir,
@ -367,6 +377,8 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1), use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu, t5_cpu=args.t5_cpu,
) )
end_time = perf_counter()
logging.info(f"Creating WanT2V pipeline took {end_time - start_time:.2f} seconds.")
logging.info( logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...") f"Generating {'image' if 't2i' in args.task else 'video'} ...")
@ -380,7 +392,6 @@ def generate(args):
guide_scale=args.sample_guide_scale, guide_scale=args.sample_guide_scale,
seed=args.base_seed, seed=args.base_seed,
offload_model=args.offload_model) offload_model=args.offload_model)
elif "i2v" in args.task: elif "i2v" in args.task:
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -414,6 +425,7 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.") logging.info("Creating WanI2V pipeline.")
start_time = perf_counter()
wan_i2v = wan.WanI2V( wan_i2v = wan.WanI2V(
config=cfg, config=cfg,
checkpoint_dir=args.ckpt_dir, checkpoint_dir=args.ckpt_dir,
@ -424,6 +436,8 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1), use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu, t5_cpu=args.t5_cpu,
) )
end_time = perf_counter()
logging.info(f"Creating WanI2V pipeline took {end_time - start_time:.2f} seconds.")
logging.info("Generating video ...") logging.info("Generating video ...")
video = wan_i2v.generate( video = wan_i2v.generate(
@ -572,6 +586,7 @@ def generate(args):
value_range=(-1, 1)) value_range=(-1, 1))
else: else:
logging.info(f"Saving generated video to {args.save_file}") logging.info(f"Saving generated video to {args.save_file}")
start_time = perf_counter()
cache_video( cache_video(
tensor=video[None], tensor=video[None],
save_file=args.save_file, save_file=args.save_file,
@ -579,6 +594,8 @@ def generate(args):
nrow=1, nrow=1,
normalize=True, normalize=True,
value_range=(-1, 1)) value_range=(-1, 1))
end_time = perf_counter()
logging.info(f"Saving Video took {end_time - start_time:.2f} seconds")
logging.info("Finished.") logging.info("Finished.")

View File

@ -3,11 +3,17 @@ import gc
from functools import partial from functools import partial
import torch import torch
from torch.cuda import empty_cache
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage from torch.distributed.utils import _free_storage
try:
import torch_musa
from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
pass
def shard_model( def shard_model(
model, model,
@ -40,4 +46,4 @@ def free_model(model):
_free_storage(m._handle.flat_param.data) _free_storage(m._handle.flat_param.data)
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() empty_cache()

View File

@ -6,20 +6,28 @@ from xfuser.core.distributed import (
get_sequence_parallel_world_size, get_sequence_parallel_world_size,
get_sp_group, get_sp_group,
) )
from xfuser.core.long_ctx_attention import xFuserLongContextAttention from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
attn_type:AttnType = AttnType.FA
try:
import torch_musa
import torch_musa.core.amp as amp
attn_type = AttnType.TORCH
except ImportError:
pass
from ..modules.model import sinusoidal_embedding_1d from ..modules.model import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len): def pad_tensor(original_tensor, target_len, pad_value=0.0):
seq_len, s1, s2 = original_tensor.shape seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len pad_size = target_len - seq_len
padding_tensor = torch.ones( padding_tensor = torch.full(
pad_size, (pad_size, s1, s2),
s1, pad_value,
s2,
dtype=original_tensor.dtype, dtype=original_tensor.dtype,
device=original_tensor.device) device=original_tensor.device,
)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor return padded_tensor
@ -32,8 +40,13 @@ def rope_apply(x, grid_sizes, freqs):
freqs: [M, C // 2]. freqs: [M, C // 2].
""" """
s, n, c = x.size(1), x.size(2), x.size(3) // 2 s, n, c = x.size(1), x.size(2), x.size(3) // 2
c0 = c - 2 * (c // 3)
c1 = c // 3
c2 = c // 3
# split freqs # split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs_real = freqs[0].split([c0, c1, c2], dim=1)
freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
# loop over samples # loop over samples
output = [] output = []
@ -41,28 +54,45 @@ def rope_apply(x, grid_sizes, freqs):
seq_len = f * h * w seq_len = f * h * w
# precompute multipliers # precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( x_i = x[i, :seq_len].reshape(s, n, -1, 2)
s, n, -1, 2)) x_real = x_i[..., 0]
freqs_i = torch.cat([ x_imag = x_i[..., 1]
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs_real = torch.cat(
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), [
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) freqs_real[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
], freqs_real[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dim=-1).reshape(seq_len, 1, -1) freqs_real[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_imag = torch.cat(
[
freqs_imag[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs_imag[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs_imag[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding # apply rotary embedding
sp_size = get_sequence_parallel_world_size() sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank() sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
x_i = torch.cat([x_i, x[i, s:]]) freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
# append to collection # append to collection
output.append(x_i) output.append(x_out)
return torch.stack(output).float() return torch.stack(output)
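
A note on the pad values chosen for pad_tensor above: the old pad_freqs padded the complex frequency table with ones, i.e. 1 + 0j, which is an identity rotation, so the split version pads the real (cos) part with 1.0 and the imaginary (sin) part with 0.0 to keep that behaviour. A standalone check of the identity (plain PyTorch, not part of the diff):

```python
import torch

x_real, x_imag = torch.randn(4), torch.randn(4)
pad_real, pad_imag = torch.ones(4), torch.zeros(4)   # cos = 1, sin = 0  <=>  1 + 0j

out_real = x_real * pad_real - x_imag * pad_imag     # stays x_real
out_imag = x_real * pad_imag + x_imag * pad_real     # stays x_imag
assert torch.equal(out_real, x_real) and torch.equal(out_imag, x_imag)
```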
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs): def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
@ -109,9 +139,13 @@ def usp_dit_forward(
if self.model_type == 'i2v': if self.model_type == 'i2v':
assert clip_fea is not None and y is not None assert clip_fea is not None and y is not None
# params # params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device device = self.patch_embedding.weight.device
if self.freqs.device != device: if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = self.freqs.to(device) self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)
if self.model_type != 'vace' and y is not None: if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -129,11 +163,9 @@ def usp_dit_forward(
]) ])
# time embeddings # time embeddings
with amp.autocast(dtype=torch.float32): e = self.time_embedding(
e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t))
sinusoidal_embedding_1d(self.freq_dim, t).float()) e0 = self.time_projection(e).unflatten(1, (6, self.dim))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context # context
context_lens = None context_lens = None
@ -177,7 +209,7 @@ def usp_dit_forward(
# unpatchify # unpatchify
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x] return x
def usp_attn_forward(self, def usp_attn_forward(self,
@ -210,7 +242,7 @@ def usp_attn_forward(self,
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
x = xFuserLongContextAttention()( x = xFuserLongContextAttention(attn_type=attn_type)(
None, None,
query=half(q), query=half(q),
key=half(k), key=half(k),

View File

@ -12,10 +12,19 @@ from functools import partial
import numpy as np import numpy as np
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist import torch.distributed as dist
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from tqdm import tqdm from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
@ -27,6 +36,7 @@ from .utils.fm_solvers import (
retrieve_timesteps, retrieve_timesteps,
) )
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanFLF2V: class WanFLF2V:
@ -66,7 +76,7 @@ class WanFLF2V:
init_on_cpu (`bool`, *optional*, defaults to True): init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP. Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = get_device(device_id)
self.config = config self.config = config
self.rank = rank self.rank = rank
self.use_usp = use_usp self.use_usp = use_usp
@ -323,7 +333,7 @@ class WanFLF2V:
} }
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
self.model.to(self.device) self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)): for _, t in enumerate(tqdm(timesteps)):
@ -336,12 +346,12 @@ class WanFLF2V:
latent_model_input, t=timestep, **arg_c)[0].to( latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device) torch.device('cpu') if offload_model else self.device)
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to( latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device) torch.device('cpu') if offload_model else self.device)
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
noise_pred = noise_pred_uncond + guide_scale * ( noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond) noise_pred_cond - noise_pred_uncond)
@ -361,7 +371,7 @@ class WanFLF2V:
if offload_model: if offload_model:
self.model.cpu() self.model.cpu()
torch.cuda.empty_cache() empty_cache()
if self.rank == 0: if self.rank == 0:
videos = self.vae.decode(x0) videos = self.vae.decode(x0)
@ -370,7 +380,7 @@ class WanFLF2V:
del sample_scheduler del sample_scheduler
if offload_model: if offload_model:
gc.collect() gc.collect()
torch.cuda.synchronize() synchronize()
if dist.is_initialized(): if dist.is_initialized():
dist.barrier() dist.barrier()

View File

@ -6,16 +6,26 @@ import os
import random import random
import sys import sys
import types import types
from time import perf_counter
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
import numpy as np import numpy as np
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist import torch.distributed as dist
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from tqdm import tqdm from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
@ -27,6 +37,7 @@ from .utils.fm_solvers import (
retrieve_timesteps, retrieve_timesteps,
) )
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanI2V: class WanI2V:
@ -66,7 +77,7 @@ class WanI2V:
init_on_cpu (`bool`, *optional*, defaults to True): init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP. Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = get_device(device_id)
self.config = config self.config = config
self.rank = rank self.rank = rank
self.use_usp = use_usp self.use_usp = use_usp
@ -220,6 +231,7 @@ class WanI2V:
n_prompt = self.sample_neg_prompt n_prompt = self.sample_neg_prompt
# preprocess # preprocess
start_time = perf_counter()
if not self.t5_cpu: if not self.t5_cpu:
self.text_encoder.model.to(self.device) self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device) context = self.text_encoder([input_prompt], self.device)
@ -231,12 +243,18 @@ class WanI2V:
context_null = self.text_encoder([n_prompt], torch.device('cpu')) context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context] context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null] context_null = [t.to(self.device) for t in context_null]
end_time = perf_counter()
logging.info(f"T5 Encoding took {end_time - start_time:.2f} seconds.")
start_time = perf_counter()
self.clip.model.to(self.device) self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]]) clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model: if offload_model:
self.clip.model.cpu() self.clip.model.cpu()
end_time = perf_counter()
logging.info(f"CLIP took {end_time - start_time:.2f} seconds.")
start_time = perf_counter()
y = self.vae.encode([ y = self.vae.encode([
torch.concat([ torch.concat([
torch.nn.functional.interpolate( torch.nn.functional.interpolate(
@ -246,6 +264,9 @@ class WanI2V:
], ],
dim=1).to(self.device) dim=1).to(self.device)
])[0] ])[0]
end_time = perf_counter()
logging.info(f"VAE Encoding took {end_time - start_time:.2f} seconds.")
y = torch.concat([msk, y]) y = torch.concat([msk, y])
@contextmanager @contextmanager
@ -296,8 +317,9 @@ class WanI2V:
} }
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
start_time = perf_counter()
self.model.to(self.device) self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)): for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)] latent_model_input = [latent.to(self.device)]
@ -309,12 +331,12 @@ class WanI2V:
latent_model_input, t=timestep, **arg_c)[0].to( latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device) torch.device('cpu') if offload_model else self.device)
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to( latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device) torch.device('cpu') if offload_model else self.device)
if offload_model: if offload_model:
torch.cuda.empty_cache() empty_cache()
noise_pred = noise_pred_uncond + guide_scale * ( noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond) noise_pred_cond - noise_pred_uncond)
@ -331,19 +353,24 @@ class WanI2V:
x0 = [latent.to(self.device)] x0 = [latent.to(self.device)]
del latent_model_input, timestep del latent_model_input, timestep
end_time = perf_counter()
logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")
if offload_model: if offload_model:
self.model.cpu() self.model.cpu()
torch.cuda.empty_cache() empty_cache()
if self.rank == 0: if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0) videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")
del noise, latent del noise, latent
del sample_scheduler del sample_scheduler
if offload_model: if offload_model:
gc.collect() gc.collect()
torch.cuda.synchronize() synchronize()
if dist.is_initialized(): if dist.is_initialized():
dist.barrier() dist.barrier()

View File

@ -1,4 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings
import torch import torch
try: try:
@ -13,7 +15,13 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False FLASH_ATTN_2_AVAILABLE = False
import warnings try:
import torch_musa
FLASH_ATTN_3_AVAILABLE = False
FLASH_ATTN_2_AVAILABLE = False
except ModuleNotFoundError:
pass
__all__ = [ __all__ = [
'flash_attention', 'flash_attention',
@ -51,7 +59,7 @@ def flash_attention(
""" """
half_dtypes = (torch.float16, torch.bfloat16) half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256 assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256
# params # params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@ -172,8 +180,9 @@ def attention(
k = k.transpose(1, 2).to(dtype) k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype) v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention( with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
out = out.transpose(1, 2).contiguous() out = out.transpose(1, 2).contiguous()
return out return out

View File

@ -6,12 +6,20 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision.transforms as T import torchvision.transforms as T
from .attention import flash_attention from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta from .xlm_roberta import XLMRoberta
try:
import torch_musa
import torch_musa.core.amp as amp
from .attention import attention as flash_attention
except ModuleNotFoundError:
pass
__all__ = [ __all__ = [
'XLMRobertaCLIP', 'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14', 'clip_xlm_roberta_vit_h_14',
@ -29,7 +37,7 @@ def pos_interpolate(pos, seq_len):
return torch.cat([ return torch.cat([
pos[:, :n], pos[:, :n],
F.interpolate( F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2), 0, 3, 1, 2),
size=(tar_grid, tar_grid), size=(tar_grid, tar_grid),
mode='bicubic', mode='bicubic',
@ -44,12 +52,6 @@ class QuickGELU(nn.Module):
return x * torch.sigmoid(1.702 * x) return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
def __init__(self, def __init__(self,
@ -82,7 +84,7 @@ class SelfAttention(nn.Module):
# compute attention # compute attention
p = self.attn_dropout if self.training else 0.0 p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
x = x.reshape(b, s, c) x = x.reshape(b, s, c)
# output # output
@ -131,10 +133,10 @@ class AttentionBlock(nn.Module):
self.norm_eps = norm_eps self.norm_eps = norm_eps
# layers # layers
self.norm1 = LayerNorm(dim, eps=norm_eps) self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout) proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps) self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu': if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else: else:
@ -177,7 +179,7 @@ class AttentionPool(nn.Module):
self.to_q = nn.Linear(dim, dim) self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2) self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim) self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps) self.norm = nn.LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)), nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(), QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@ -259,13 +261,13 @@ class VisionTransformer(nn.Module):
self.dropout = nn.Dropout(embedding_dropout) self.dropout = nn.Dropout(embedding_dropout)
# transformer # transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[ self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps) activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.post_norm = LayerNorm(dim, eps=norm_eps) self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
# head # head
if pool_type == 'token': if pool_type == 'token':
@ -537,6 +539,6 @@ class CLIPModel:
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward # forward
with torch.cuda.amp.autocast(dtype=self.dtype): with amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True) out = self.model.visual(videos, use_31_block=True)
return out return out

View File

@ -7,7 +7,15 @@ import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention from wan.utils.platform import get_device
from wan.modules.attention import flash_attention
try:
import torch_musa
import torch_musa.core.amp as amp
from wan.modules.attention import attention as flash_attention
except ModuleNotFoundError:
pass
__all__ = ['WanModel'] __all__ = ['WanModel']
@ -19,7 +27,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess # preprocess
assert dim % 2 == 0 assert dim % 2 == 0
half = dim // 2 half = dim // 2
position = position.type(torch.float64) position = position.type(torch.bfloat16)
# calculation # calculation
sinusoid = torch.outer( sinusoid = torch.outer(
@ -29,22 +37,45 @@ def sinusoidal_embedding_1d(dim, position):
@amp.autocast(enabled=False) @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000): def rope_params_real(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0 assert dim % 2 == 0
freqs = torch.outer( freqs_real = torch.outer(
torch.arange(max_seq_len), torch.arange(max_seq_len, dtype=dtype, device=device),
1.0 / torch.pow(theta, 1.0
torch.arange(0, dim, 2).to(torch.float64).div(dim))) / torch.pow(
freqs = torch.polar(torch.ones_like(freqs), freqs) theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
return freqs ),
)
return torch.cos(freqs_real)
@amp.autocast(enabled=False)
def rope_params_imag(
max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
):
assert dim % 2 == 0
freqs_imag = torch.outer(
torch.arange(max_seq_len, dtype=dtype, device=device),
1.0
/ torch.pow(
theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
),
)
return torch.sin(freqs_imag)
@amp.autocast(enabled=False) @amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs): def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2 n, c = x.size(2), x.size(3) // 2
c0 = c - 2 * (c // 3)
c1 = c // 3
c2 = c // 3
# split freqs # split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs_real = freqs[0].split([c0, c1, c2], dim=1)
freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
# loop over samples # loop over samples
output = [] output = []
@ -52,22 +83,36 @@ def rope_apply(x, grid_sizes, freqs):
seq_len = f * h * w seq_len = f * h * w
# precompute multipliers # precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
seq_len, n, -1, 2)) x_real = x_i[..., 0]
freqs_i = torch.cat([ x_imag = x_i[..., 1]
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs_real = torch.cat(
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), [
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
], freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
dim=-1).reshape(seq_len, 1, -1) freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1,
).reshape(seq_len, 1, c)
freqs_imag = torch.cat(
[
freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
],
dim=-1,
).reshape(seq_len, 1, c)
out_real = x_real * freqs_real - x_imag * freqs_imag
out_imag = x_real * freqs_imag + x_imag * freqs_real
# apply rotary embedding # apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2) x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]]) x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
# append to collection # append to collection
output.append(x_i) output.append(x_out)
return torch.stack(output).float() return torch.stack(output)
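
The rope_apply rewrite above drops torch.view_as_complex/torch.polar in favour of an explicit real/imaginary split, which avoids complex tensor kernels on backends where they are limited. It applies (x_r + i*x_i)(cos t + i*sin t) = (x_r*cos t - x_i*sin t) + i*(x_r*sin t + x_i*cos t). A standalone sketch checking that the two formulations agree (plain PyTorch on CPU, not part of the diff):

```python
import torch

x = torch.randn(8, 2)      # interleaved (real, imag) pairs, as in rope_apply
theta = torch.rand(8)      # rotation angles

# original formulation: complex multiply against a unit phasor
ref = torch.view_as_complex(x.double()) * torch.polar(
    torch.ones(8, dtype=torch.float64), theta.double())
ref = torch.view_as_real(ref).float()

# rewritten formulation: explicit real/imag arithmetic
out_real = x[:, 0] * torch.cos(theta) - x[:, 1] * torch.sin(theta)
out_imag = x[:, 0] * torch.sin(theta) + x[:, 1] * torch.cos(theta)
out = torch.stack([out_real, out_imag], dim=-1)

assert torch.allclose(ref, out, atol=1e-5)
```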
class WanRMSNorm(nn.Module): class WanRMSNorm(nn.Module):
@ -89,19 +134,6 @@ class WanRMSNorm(nn.Module):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module): class WanSelfAttention(nn.Module):
def __init__(self, def __init__(self,
@ -256,10 +288,10 @@ class WanAttentionBlock(nn.Module):
self.eps = eps self.eps = eps
# layers # layers
self.norm1 = WanLayerNorm(dim, eps) self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps) eps)
self.norm3 = WanLayerNorm( self.norm3 = nn.LayerNorm(
dim, eps, dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity() elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@ -267,7 +299,7 @@ class WanAttentionBlock(nn.Module):
(-1, -1), (-1, -1),
qk_norm, qk_norm,
eps) eps)
self.norm2 = WanLayerNorm(dim, eps) self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.ffn = nn.Sequential( self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim)) nn.Linear(ffn_dim, dim))
@ -293,24 +325,19 @@ class WanAttentionBlock(nn.Module):
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
""" """
assert e.dtype == torch.float32 e = (self.modulation + e).chunk(6, dim=1)
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention # self-attention
y = self.self_attn( y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs) freqs)
with amp.autocast(dtype=torch.float32): x = x + y * e[2]
x = x + y * e[2]
# cross-attention & ffn function # cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e): def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens) x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32): x = x + y * e[5]
x = x + y * e[5]
return x return x
x = cross_attn_ffn(x, context, context_lens, e) x = cross_attn_ffn(x, context, context_lens, e)
@ -328,7 +355,7 @@ class Head(nn.Module):
# layers # layers
out_dim = math.prod(patch_size) * out_dim out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps) self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.head = nn.Linear(dim, out_dim) self.head = nn.Linear(dim, out_dim)
# modulation # modulation
@ -340,10 +367,8 @@ class Head(nn.Module):
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C] e(Tensor): Shape [B, C]
""" """
assert e.dtype == torch.float32 e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
with amp.autocast(dtype=torch.float32): x = self.head(self.norm(x) * (1 + e[1]) + e[0])
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x return x
@ -477,12 +502,23 @@ class WanModel(ModelMixin, ConfigMixin):
# buffers (don't use register_buffer otherwise dtype will be changed in to()) # buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads d = dim // num_heads
self.freqs = torch.cat([ freqs_real = torch.cat(
rope_params(1024, d - 4 * (d // 6)), [
rope_params(1024, 2 * (d // 6)), rope_params_real(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)) rope_params_real(1024, 2 * (d // 6)),
], rope_params_real(1024, 2 * (d // 6)),
dim=1) ],
dim=1,
)
freqs_imag = torch.cat(
[
rope_params_imag(1024, d - 4 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
rope_params_imag(1024, 2 * (d // 6)),
],
dim=1,
)
self.freqs = (freqs_real, freqs_imag)
if model_type == 'i2v' or model_type == 'flf2v': if model_type == 'i2v' or model_type == 'flf2v':
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@ -523,9 +559,13 @@ class WanModel(ModelMixin, ConfigMixin):
if self.model_type == 'i2v' or self.model_type == 'flf2v': if self.model_type == 'i2v' or self.model_type == 'flf2v':
assert clip_fea is not None and y is not None assert clip_fea is not None and y is not None
# params # params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device device = self.patch_embedding.weight.device
if self.freqs.device != device: if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = self.freqs.to(device) self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)
if y is not None: if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -543,11 +583,9 @@ class WanModel(ModelMixin, ConfigMixin):
]) ])
# time embeddings # time embeddings
with amp.autocast(dtype=torch.float32): e = self.time_embedding(
e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t))
sinusoidal_embedding_1d(self.freq_dim, t).float()) e0 = self.time_projection(e).unflatten(1, (6, self.dim))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context # context
context_lens = None context_lens = None
@ -579,7 +617,7 @@ class WanModel(ModelMixin, ConfigMixin):
# unpatchify # unpatchify
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x] return x
def unpatchify(self, x, grid_sizes): def unpatchify(self, x, grid_sizes):
r""" r"""

View File

@ -7,7 +7,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer try:
import torch_musa
from torch_musa.core.device import current_device
except ModuleNotFoundError:
pass
from wan.modules.tokenizers import HuggingfaceTokenizer
from wan.utils.platform import get_device
__all__ = [ __all__ = [
'T5Model', 'T5Model',
@ -59,10 +66,8 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(dim)) self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x): def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
self.eps) self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x return self.weight * x
@ -110,7 +115,7 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling) # compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn) attn = F.softmax(attn, dim=-1)
x = torch.einsum('bnij,bjnc->binc', attn, v) x = torch.einsum('bnij,bjnc->binc', attn, v)
# output # output
@ -255,7 +260,7 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions # embeddings for small and large positions
max_exact = num_buckets // 2 max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
math.log(self.max_dist / max_exact) * math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long() (num_buckets - max_exact)).long()
rel_pos_large = torch.min( rel_pos_large = torch.min(
@ -475,7 +480,7 @@ class T5EncoderModel:
self, self,
text_len, text_len,
dtype=torch.bfloat16, dtype=torch.bfloat16,
device=torch.cuda.current_device(), device=get_device(),
checkpoint_path=None, checkpoint_path=None,
tokenizer_path=None, tokenizer_path=None,
shard_fn=None, shard_fn=None,

View File

@ -4,6 +4,12 @@ import torch.cuda.amp as amp
import torch.nn as nn import torch.nn as nn
from diffusers.configuration_utils import register_to_config from diffusers.configuration_utils import register_to_config
try:
import torch_musa
import torch_musa.core.amp as amp
except ModuleNotFoundError:
pass
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d

View File

@ -1,12 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging import logging
from math import sqrt
import torch import torch
import torch.cuda.amp as amp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Upsample
from einops import rearrange from einops import rearrange
try:
import torch_musa
except ModuleNotFoundError:
pass
from wan.utils.platform import get_device
__all__ = [ __all__ = [
'WanVAE', 'WanVAE',
] ]
@ -44,23 +52,17 @@ class RMS_norm(nn.Module):
shape = (dim, *broadcastable_dims) if channel_first else (dim,) shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first self.channel_first = channel_first
self.scale = dim**0.5 self.scale = sqrt(dim)
self.gamma = nn.Parameter(torch.ones(shape)) self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x): def forward(self, x):
return F.normalize( return (
x, dim=(1 if self.channel_first else F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
-1)) * self.scale * self.gamma + self.bias * self.scale
* self.gamma
+ self.bias
class Upsample(nn.Upsample): )
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module): class Resample(nn.Module):
@ -253,6 +255,10 @@ class AttentionBlock(nn.Module):
q, q,
k, k,
v, v,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
) )
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
@ -621,8 +627,8 @@ class WanVAE:
def __init__(self, def __init__(self,
z_dim=16, z_dim=16,
vae_pth='cache/vae_step_411000.pth', vae_pth='cache/vae_step_411000.pth',
dtype=torch.float, dtype=torch.bfloat16,
device="cuda"): device=get_device()):
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
@ -648,16 +654,12 @@ class WanVAE:
""" """
videos: A list of videos each with shape [C, T, H, W]. videos: A list of videos each with shape [C, T, H, W].
""" """
with amp.autocast(dtype=self.dtype): return [
return [ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) ]
for u in videos
]
def decode(self, zs): def decode(self, zs):
with amp.autocast(dtype=self.dtype): return [
return [ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
self.model.decode(u.unsqueeze(0), for u in zs
self.scale).float().clamp_(-1, 1).squeeze(0) ]
for u in zs
]

View File

@ -6,14 +6,24 @@ import os
import random import random
import sys import sys
import types import types
from time import perf_counter
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist import torch.distributed as dist
from tqdm import tqdm from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .distributed.fsdp import shard_model from .distributed.fsdp import shard_model
from .modules.model import WanModel from .modules.model import WanModel
from .modules.t5 import T5EncoderModel from .modules.t5 import T5EncoderModel
@ -24,7 +34,7 @@ from .utils.fm_solvers import (
retrieve_timesteps, retrieve_timesteps,
) )
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.platform import get_device
class WanT2V: class WanT2V:
@ -60,7 +70,7 @@ class WanT2V:
t5_cpu (`bool`, *optional*, defaults to False): t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp. Whether to place T5 model on CPU. Only works without t5_fsdp.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = get_device(device_id)
self.config = config self.config = config
self.rank = rank self.rank = rank
self.t5_cpu = t5_cpu self.t5_cpu = t5_cpu
@ -171,6 +181,7 @@ class WanT2V:
seed_g = torch.Generator(device=self.device) seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed) seed_g.manual_seed(seed)
start_time = perf_counter()
if not self.t5_cpu: if not self.t5_cpu:
self.text_encoder.model.to(self.device) self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device) context = self.text_encoder([input_prompt], self.device)
@ -182,6 +193,8 @@ class WanT2V:
context_null = self.text_encoder([n_prompt], torch.device('cpu')) context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context] context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null] context_null = [t.to(self.device) for t in context_null]
end_time = perf_counter()
logging.info(f"T5 Encoding Context took {end_time - start_time:.2f} seconds.")
noise = [ noise = [
torch.randn( torch.randn(
@ -230,13 +243,14 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len} arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len}
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)): for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents latent_model_input = latents
timestep = [t] timestep = [t]
timestep = torch.stack(timestep) timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model( noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0] latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
@ -252,19 +266,24 @@ class WanT2V:
return_dict=False, return_dict=False,
generator=seed_g)[0] generator=seed_g)[0]
latents = [temp_x0.squeeze(0)] latents = [temp_x0.squeeze(0)]
end_time = perf_counter()
logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")
x0 = latents x0 = latents
if offload_model: if offload_model:
self.model.cpu() self.model.cpu()
torch.cuda.empty_cache() empty_cache()
if self.rank == 0: if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0) videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")
del noise, latents del noise, latents
del sample_scheduler del sample_scheduler
if offload_model: if offload_model:
gc.collect() gc.collect()
torch.cuda.synchronize() synchronize()
if dist.is_initialized(): if dist.is_initialized():
dist.barrier() dist.barrier()

View File

@ -5,9 +5,10 @@ from .fm_solvers import (
) )
from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor from .vace_processor import VaceVideoProcessor
from .platform import get_device, get_torch_distributed_backend
__all__ = [ __all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor' 'VaceVideoProcessor', 'get_device', 'get_torch_distributed_backend'
] ]

wan/utils/platform.py (new file, 34 lines)
View File

@ -0,0 +1,34 @@
from typing import Optional

import torch

try:
    import torch_musa
except ModuleNotFoundError:
    pass


def _is_musa() -> bool:
    try:
        return torch.musa.is_available()
    except (AttributeError, ModuleNotFoundError):
        # torch_musa is not installed, so torch exposes no "musa" namespace
        return False


def get_device(local_rank: Optional[int] = None) -> torch.device:
    if torch.cuda.is_available():
        return torch.cuda.current_device() if local_rank is None else torch.device("cuda", local_rank)
    elif _is_musa():
        return torch.musa.current_device() if local_rank is None else torch.device("musa", local_rank)
    else:
        return torch.device("cpu")


def get_torch_distributed_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"
    elif _is_musa():
        return "mccl"
    else:
        raise NotImplementedError("No accelerators (NVIDIA, AMD, or MTT GPUs) available")
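
How these helpers are consumed elsewhere in the change (the generate.py hunk above uses exactly this shape); a standalone sketch, assuming MASTER_ADDR/MASTER_PORT are set for env:// initialization:

```python
import torch.distributed as dist

from wan.utils.platform import get_device, get_torch_distributed_backend

local_rank = 0
device = get_device(local_rank)               # cuda:0, musa:0, or cpu
dist.init_process_group(
    backend=get_torch_distributed_backend(),  # "nccl" on NVIDIA, "mccl" on Moore Threads
    init_method="env://",
    rank=0,
    world_size=1)
```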

View File

@ -13,6 +13,7 @@ from functools import partial
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.cuda import empty_cache, synchronize
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn.functional as F import torch.nn.functional as F
@ -20,6 +21,14 @@ import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
try:
import torch_musa
import torch_musa.core.amp as amp
from torch_musa.core.memory import empty_cache
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
from .modules.vace_model import VaceWanModel from .modules.vace_model import VaceWanModel
from .text2video import ( from .text2video import (
FlowDPMSolverMultistepScheduler, FlowDPMSolverMultistepScheduler,
@ -32,6 +41,7 @@ from .text2video import (
shard_model, shard_model,
) )
from .utils.vace_processor import VaceVideoProcessor from .utils.vace_processor import VaceVideoProcessor
from .utils.platform import get_device, get_torch_distributed_backend
class WanVace(WanT2V): class WanVace(WanT2V):
@ -68,7 +78,7 @@ class WanVace(WanT2V):
t5_cpu (`bool`, *optional*, defaults to False): t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp. Whether to place T5 model on CPU. Only works without t5_fsdp.
""" """
self.device = torch.device(f"cuda:{device_id}") self.device = get_device(device_id)
self.config = config self.config = config
self.rank = rank self.rank = rank
self.t5_cpu = t5_cpu self.t5_cpu = t5_cpu
@ -460,7 +470,7 @@ class WanVace(WanT2V):
x0 = latents x0 = latents
if offload_model: if offload_model:
self.model.cpu() self.model.cpu()
torch.cuda.empty_cache() empty_cache()
if self.rank == 0: if self.rank == 0:
videos = self.decode_latent(x0, input_ref_images) videos = self.decode_latent(x0, input_ref_images)
@ -468,7 +478,7 @@ class WanVace(WanT2V):
del sample_scheduler del sample_scheduler
if offload_model: if offload_model:
gc.collect() gc.collect()
torch.cuda.synchronize() synchronize()
if dist.is_initialized(): if dist.is_initialized():
dist.barrier() dist.barrier()
@ -568,7 +578,7 @@ class WanVaceMP(WanVace):
torch.cuda.set_device(gpu) torch.cuda.set_device(gpu)
dist.init_process_group( dist.init_process_group(
backend='nccl', backend=get_torch_distributed_backend(),
init_method='env://', init_method='env://',
rank=rank, rank=rank,
world_size=world_size) world_size=world_size)
@ -633,7 +643,7 @@ class WanVaceMP(WanVace):
model = shard_fn(model) model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache() empty_cache()
event = initialized_events[gpu] event = initialized_events[gpu]
in_q = in_q_list[gpu] in_q = in_q_list[gpu]
event.set() event.set()
@ -748,7 +758,7 @@ class WanVaceMP(WanVace):
generator=seed_g)[0] generator=seed_g)[0]
latents = [temp_x0.squeeze(0)] latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache() empty_cache()
x0 = latents x0 = latents
if rank == 0: if rank == 0:
videos = self.decode_latent( videos = self.decode_latent(
@ -758,7 +768,7 @@ class WanVaceMP(WanVace):
del sample_scheduler del sample_scheduler
if offload_model: if offload_model:
gc.collect() gc.collect()
torch.cuda.synchronize() synchronize()
if dist.is_initialized(): if dist.is_initialized():
dist.barrier() dist.barrier()