mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-20 22:22:15 +00:00

Compare commits (2 commits): 618d94c564 ... 43ac073411
| Author | SHA1 | Date |
|---|---|---|
| | 43ac073411 | |
| | 447aa08620 | |
13 generate.py
@@ -4,7 +4,6 @@ import logging
import os
import sys
import warnings
from time import perf_counter
from datetime import datetime

warnings.filterwarnings('ignore')
@@ -20,7 +19,7 @@ try:
    import torch_musa
    from torch_musa.core.device import set_device
except ModuleNotFoundError:
    pass
    torch_musa = None

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
@@ -366,7 +365,6 @@ def generate(args):
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        start_time = perf_counter()
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
@@ -377,8 +375,6 @@ def generate(args):
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
        end_time = perf_counter()
        logging.info(f"Creating WanT2V pipeline took {end_time - start_time:.2f} seconds.")

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
@@ -392,6 +388,7 @@ def generate(args):
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    elif "i2v" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -425,7 +422,6 @@ def generate(args):
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanI2V pipeline.")
        start_time = perf_counter()
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
@@ -436,8 +432,6 @@ def generate(args):
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
        end_time = perf_counter()
        logging.info(f"Creating WanI2V pipeline took {end_time - start_time:.2f} seconds.")

        logging.info("Generating video ...")
        video = wan_i2v.generate(
@@ -586,7 +580,6 @@ def generate(args):
                value_range=(-1, 1))
        else:
            logging.info(f"Saving generated video to {args.save_file}")
            start_time = perf_counter()
            cache_video(
                tensor=video[None],
                save_file=args.save_file,
@@ -594,8 +587,6 @@ def generate(args):
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
            end_time = perf_counter()
            logging.info(f"Saving Video took {end_time - start_time:.2f} seconds")
    logging.info("Finished.")
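The timing lines added and removed above all follow the same perf_counter pattern around a pipeline step. A minimal, hypothetical helper capturing that pattern (the `timed` wrapper below is illustrative only, not part of the repository):

```python
import logging
import time
from time import perf_counter

logging.basicConfig(level=logging.INFO)


def timed(label, fn, *args, **kwargs):
    # Run fn, log its wall-clock duration, and return its result.
    start_time = perf_counter()
    result = fn(*args, **kwargs)
    end_time = perf_counter()
    logging.info(f"{label} took {end_time - start_time:.2f} seconds.")
    return result


timed("Sleeping", time.sleep, 0.1)
```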
@@ -13,7 +13,7 @@ try:
    import torch_musa
    from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
    pass
    torch_musa = None


def shard_model(
    model,
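Throughout this diff the optional-import blocks are changed from `pass` to `torch_musa = None`, so later code can branch on whether the MUSA backend was actually imported. A self-contained sketch of that sentinel pattern (`backend_name` is a made-up helper for illustration):

```python
try:
    import torch_musa  # optional Moore Threads (MUSA) backend
except ModuleNotFoundError:
    torch_musa = None  # sentinel: MUSA is unavailable, use the default path


def backend_name() -> str:
    # Branch on the sentinel instead of re-trying the import everywhere.
    return "musa" if torch_musa is not None else "cuda"


print(backend_name())
```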
@@ -14,11 +14,24 @@ try:
    import torch_musa.core.amp as amp
    attn_type = AttnType.TORCH
except ImportError:
    pass
    torch_musa = None

from ..modules.model import sinusoidal_embedding_1d


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def pad_tensor(original_tensor, target_len, pad_value=0.0):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
@@ -34,6 +47,47 @@ def pad_tensor(original_tensor, target_len, pad_value=0.0):

@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


@amp.autocast(enabled=False)
def rope_apply_musa(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
@@ -141,11 +195,15 @@ def usp_dit_forward(
    # params
    dtype = self.patch_embedding.weight.dtype
    device = self.patch_embedding.weight.device
    if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
        self.freqs = (
            self.freqs[0].to(dtype=dtype, device=device),
            self.freqs[-1].to(dtype=dtype, device=device)
        )
    if torch_musa is not None:
        if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
            self.freqs = (
                self.freqs[0].to(dtype=dtype, device=device),
                self.freqs[-1].to(dtype=dtype, device=device)
            )
    else:
        if self.freqs.dtype != dtype or self.freqs.device != device:
            self.freqs = self.freqs.to(dtype=dtype, device=device)

    if self.model_type != 'vace' and y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -163,9 +221,11 @@ def usp_dit_forward(
    ])

    # time embeddings
    e = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, t))
    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
@@ -209,7 +269,7 @@ def usp_dit_forward(

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return x
    return [u.float() for u in x]


def usp_attn_forward(self,
@@ -232,8 +292,13 @@ def usp_attn_forward(self,
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    if torch_musa is not None:
        q = rope_apply_musa(q, grid_sizes, freqs)
        k = rope_apply_musa(k, grid_sizes, freqs)
    else:
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpaded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
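`pad_freqs` and the per-rank slicing above split the rotary table across sequence-parallel ranks: the table is padded up to `s * sp_size` rows and each rank takes its own contiguous slice. A toy, single-process sketch of that slicing (sizes are arbitrary):

```python
import torch


def pad_freqs(original_tensor, target_len):
    # Pad the frequency table with ones (identity rotation) up to target_len.
    seq_len, s1, s2 = original_tensor.shape
    pad = torch.ones(target_len - seq_len, s1, s2,
                     dtype=original_tensor.dtype, device=original_tensor.device)
    return torch.cat([original_tensor, pad], dim=0)


# Toy sizes: 10 positions, 4 sequence-parallel ranks, 3 positions per rank.
freqs = torch.randn(10, 1, 8)
sp_size, s_per_rank = 4, 3
freqs = pad_freqs(freqs, s_per_rank * sp_size)  # -> [12, 1, 8]
for sp_rank in range(sp_size):
    chunk = freqs[sp_rank * s_per_rank:(sp_rank + 1) * s_per_rank]
    print(sp_rank, chunk.shape)  # every rank sees a [3, 1, 8] slice
```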
@@ -23,7 +23,7 @@ try:
    from torch_musa.core.memory import empty_cache
    from torch_musa.core.device import synchronize
except ModuleNotFoundError:
    pass
    torch_musa = None

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
@@ -6,7 +6,6 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial

@@ -24,7 +23,7 @@ try:
    from torch_musa.core.memory import empty_cache
    from torch_musa.core.device import synchronize
except ModuleNotFoundError:
    pass
    torch_musa = None

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
@@ -231,7 +230,6 @@ class WanI2V:
            n_prompt = self.sample_neg_prompt

        # preprocess
        start_time = perf_counter()
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
@@ -243,18 +241,12 @@ class WanI2V:
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]
        end_time = perf_counter()
        logging.info(f"T5 Encoding took {end_time - start_time:.2f} seconds.")

        start_time = perf_counter()
        self.clip.model.to(self.device)
        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()
        end_time = perf_counter()
        logging.info(f"CLIP took {end_time - start_time:.2f} seconds.")

        start_time = perf_counter()
        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
@@ -264,9 +256,6 @@ class WanI2V:
            ],
                         dim=1).to(self.device)
        ])[0]
        end_time = perf_counter()
        logging.info(f"VAE Encoding took {end_time - start_time:.2f} seconds.")

        y = torch.concat([msk, y])

        @contextmanager
@@ -319,7 +308,6 @@ class WanI2V:
            if offload_model:
                empty_cache()

            start_time = perf_counter()
            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
@@ -353,18 +341,13 @@ class WanI2V:

            x0 = [latent.to(self.device)]
            del latent_model_input, timestep
            end_time = perf_counter()
            logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")

            if offload_model:
                self.model.cpu()
                empty_cache()

            if self.rank == 0:
                start_time = perf_counter()
                videos = self.vae.decode(x0)
                end_time = perf_counter()
                logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")

        del noise, latent
        del sample_scheduler
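The `offload_model` branches above move each sub-model onto the accelerator only while it is needed, then push it back to the CPU and release cached memory. A hypothetical sketch of that pattern (`on_device` is an illustrative helper; on MUSA the `empty_cache` would come from `torch_musa.core.memory`, as in the hunks above):

```python
import torch
import torch.nn as nn
from contextlib import contextmanager
from torch.cuda import empty_cache  # torch_musa.core.memory.empty_cache on MUSA


@contextmanager
def on_device(module: nn.Module, device, offload: bool = True):
    # Keep the module on the accelerator only for the duration of the block,
    # then offload it back to the CPU and release cached allocator memory.
    module.to(device)
    try:
        yield module
    finally:
        if offload:
            module.cpu()
            empty_cache()


device = "cuda" if torch.cuda.is_available() else "cpu"
with on_device(nn.Linear(4, 4), device) as enc:
    print(enc(torch.randn(1, 4, device=device)).shape)
```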
@@ -20,7 +20,7 @@ try:
    FLASH_ATTN_3_AVAILABLE = False
    FLASH_ATTN_2_AVAILABLE = False
except ModuleNotFoundError:
    pass
    torch_musa = None


__all__ = [
@@ -180,9 +180,8 @@ def attention(
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)

        out = out.transpose(1, 2).contiguous()
    return out
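The attention hunk drops the CUDA-only `torch.backends.cuda.sdp_kernel` context and calls `scaled_dot_product_attention` directly, leaving backend selection to PyTorch; that reading is inferred from the diff. A minimal standalone call:

```python
import torch
import torch.nn.functional as F

# [batch, heads, seq, head_dim] layout expected by scaled_dot_product_attention.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Without an sdp_kernel() context, PyTorch chooses whichever backend
# (flash, memory-efficient, or math) the current device supports.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```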
@@ -18,7 +18,7 @@ try:
    import torch_musa.core.amp as amp
    from .attention import attention as flash_attention
except ModuleNotFoundError:
    pass
    torch_musa = None

__all__ = [
    'XLMRobertaCLIP',
@@ -37,7 +37,7 @@ def pos_interpolate(pos, seq_len):
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
                    0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
@@ -52,6 +52,12 @@ class QuickGELU(nn.Module):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
@@ -84,7 +90,10 @@ class SelfAttention(nn.Module):

        # compute attention
        p = self.attn_dropout if self.training else 0.0
        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
        if torch_musa is not None:
            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
        else:
            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
        x = x.reshape(b, s, c)

        # output
@@ -133,10 +142,10 @@ class AttentionBlock(nn.Module):
        self.norm_eps = norm_eps

        # layers
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
@@ -179,7 +188,7 @@ class AttentionPool(nn.Module):
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=norm_eps)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@@ -196,7 +205,10 @@ class AttentionPool(nn.Module):
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        if torch_musa is not None:
            x = flash_attention(q, k, v)
        else:
            x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
@@ -261,13 +273,13 @@ class VisionTransformer(nn.Module):
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
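Several hunks replace `nn.LayerNorm` with a subclass that normalizes in float32 and casts back, keeping the reduction stable under bf16/fp16 inputs. A self-contained sketch (the class is named `Float32LayerNorm` here for illustration; the diff calls it `LayerNorm` / `WanLayerNorm`):

```python
import torch
import torch.nn as nn


class Float32LayerNorm(nn.LayerNorm):
    # Same parameters as nn.LayerNorm; only the compute dtype changes.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type_as(x)


x = torch.randn(2, 16, dtype=torch.bfloat16)
print(Float32LayerNorm(16)(x).dtype)  # bfloat16 output, float32 math inside
```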
@@ -7,7 +7,6 @@ import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from wan.utils.platform import get_device
from wan.modules.attention import flash_attention

try:
@@ -15,7 +14,7 @@ try:
    import torch_musa.core.amp as amp
    from wan.modules.attention import attention as flash_attention
except ModuleNotFoundError:
    pass
    torch_musa = None

__all__ = ['WanModel']

@@ -27,7 +26,7 @@ def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.bfloat16)
    position = position.type(torch.float32)

    # calculation
    sinusoid = torch.outer(
@@ -36,6 +35,17 @@ def sinusoidal_embedding_1d(dim, position):
    return x


@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


@amp.autocast(enabled=False)
def rope_params_real(
    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
@@ -69,6 +79,37 @@ def rope_params_imag(
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


@amp.autocast(enabled=False)
def rope_apply_musa(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2
    c0 = c - 2 * (c // 3)
    c1 = c // 3
    c2 = c // 3
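The MUSA path stores the rotary table as a (real, imaginary) pair instead of a single complex tensor, presumably because complex arithmetic is limited on that backend; the `rope_params_real` / `rope_params_imag` bodies are not shown in this hunk. A sketch of the equivalence, using the complex `rope_params` from the diff:

```python
import torch


def rope_params(max_seq_len, dim, theta=10000):
    # Complex form from the diff: e^{i * m * omega_k}, shape [max_seq_len, dim // 2].
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)


complex_freqs = rope_params(1024, 64)
# A (real, imag) pair carries the same information as the complex tensor.
freqs_real, freqs_imag = complex_freqs.real, complex_freqs.imag
assert torch.allclose(torch.complex(freqs_real, freqs_imag), complex_freqs)
```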
@@ -134,6 +175,19 @@ class WanRMSNorm(nn.Module):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return super().forward(x.float()).type_as(x)


class WanSelfAttention(nn.Module):

    def __init__(self,
@@ -178,12 +232,22 @@ class WanSelfAttention(nn.Module):

        q, k, v = qkv_fn(x)

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)
        if torch_musa is not None:
            x = flash_attention(
                q=rope_apply_musa(q, grid_sizes, freqs),
                k=rope_apply_musa(k, grid_sizes, freqs),
                v=v,
                k_lens=seq_lens,
                window_size=self.window_size,
            )
        else:
            x = flash_attention(
                q=rope_apply(q, grid_sizes, freqs),
                k=rope_apply(k, grid_sizes, freqs),
                v=v,
                k_lens=seq_lens,
                window_size=self.window_size,
            )

        # output
        x = x.flatten(2)
@@ -288,10 +352,10 @@ class WanAttentionBlock(nn.Module):
        self.eps = eps

        # layers
        self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = nn.LayerNorm(
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@@ -299,7 +363,7 @@ class WanAttentionBlock(nn.Module):
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps)
        self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))
@@ -325,19 +389,24 @@ class WanAttentionBlock(nn.Module):
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        e = (self.modulation + e).chunk(6, dim=1)
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            e = (self.modulation + e).chunk(6, dim=1)
            assert e[0].dtype == torch.float32

        # self-attention
        y = self.self_attn(
            self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
            freqs)
        x = x + y * e[2]
        with amp.autocast(dtype=torch.float32):
            x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
            x = x + y * e[5]
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with amp.autocast(dtype=torch.float32):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
@@ -355,7 +424,7 @@ class Head(nn.Module):

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
@@ -367,8 +436,10 @@ class Head(nn.Module):
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
        x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x


@@ -502,23 +573,33 @@ class WanModel(ModelMixin, ConfigMixin):
        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        freqs_real = torch.cat(
            [
                rope_params_real(1024, d - 4 * (d // 6)),
                rope_params_real(1024, 2 * (d // 6)),
                rope_params_real(1024, 2 * (d // 6)),
            ],
            dim=1,
        )
        freqs_imag = torch.cat(
            [
                rope_params_imag(1024, d - 4 * (d // 6)),
                rope_params_imag(1024, 2 * (d // 6)),
                rope_params_imag(1024, 2 * (d // 6)),
            ],
            dim=1,
        )
        self.freqs = (freqs_real, freqs_imag)
        if torch_musa is not None:
            freqs_real = torch.cat(
                [
                    rope_params_real(1024, d - 4 * (d // 6)),
                    rope_params_real(1024, 2 * (d // 6)),
                    rope_params_real(1024, 2 * (d // 6)),
                ],
                dim=1,
            )
            freqs_imag = torch.cat(
                [
                    rope_params_imag(1024, d - 4 * (d // 6)),
                    rope_params_imag(1024, 2 * (d // 6)),
                    rope_params_imag(1024, 2 * (d // 6)),
                ],
                dim=1,
            )
            self.freqs = (freqs_real, freqs_imag)
        else:
            self.freqs = torch.cat(
                [
                    rope_params(1024, d - 4 * (d // 6)),
                    rope_params(1024, 2 * (d // 6)),
                    rope_params(1024, 2 * (d // 6)),
                ],
                dim=1,
            )

        if model_type == 'i2v' or model_type == 'flf2v':
            self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@@ -561,11 +642,15 @@ class WanModel(ModelMixin, ConfigMixin):
        # params
        dtype = self.patch_embedding.weight.dtype
        device = self.patch_embedding.weight.device
        if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
            self.freqs = (
                self.freqs[0].to(dtype=dtype, device=device),
                self.freqs[-1].to(dtype=dtype, device=device)
            )
        if torch_musa is not None:
            if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
                self.freqs = (
                    self.freqs[0].to(dtype=dtype, device=device),
                    self.freqs[-1].to(dtype=dtype, device=device)
                )
        else:
            if self.freqs.dtype != dtype or self.freqs.device != device:
                self.freqs = self.freqs.to(dtype=dtype, device=device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -583,9 +668,11 @@ class WanModel(ModelMixin, ConfigMixin):
        ])

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
@@ -617,7 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
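The modulation and time-embedding hunks above wrap the precision-sensitive additions in `amp.autocast(dtype=torch.float32)` and assert that the result stays float32. A toy sketch of that idea, mirroring the diff's usage (`torch.cuda.amp` here; `torch_musa.core.amp` on MUSA):

```python
import torch
import torch.cuda.amp as amp

modulation = torch.randn(1, 6, 32)  # stand-in for the learned modulation table


def modulate(e: torch.Tensor):
    # Force float32 for the additive modulation even if the caller is running
    # the rest of the block under a low-precision autocast.
    with amp.autocast(dtype=torch.float32):
        e = (modulation + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32
    return e


print(len(modulate(torch.randn(1, 6, 32))))  # 6 chunks
```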
@@ -6,15 +6,15 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import current_device

try:
    import torch_musa
    from torch_musa.core.device import current_device
except ModuleNotFoundError:
    pass
    torch_musa = None

from wan.modules.tokenizers import HuggingfaceTokenizer
from wan.utils.platform import get_device
from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
@@ -66,8 +66,10 @@ class T5LayerNorm(nn.Module):
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x
@@ -115,7 +117,7 @@ class T5Attention(nn.Module):

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn, dim=-1)
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
@@ -260,7 +262,7 @@ class T5RelativeEmbedding(nn.Module):

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
@@ -480,7 +482,7 @@ class T5EncoderModel:
        self,
        text_len,
        dtype=torch.bfloat16,
        device=get_device(),
        device=current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
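The T5 hunks upcast to float32 before softmax, rsqrt, and log and cast back afterwards, avoiding precision loss when the encoder runs in half precision. A minimal sketch of the softmax case:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 8, 16, 16, dtype=torch.bfloat16)

# Upcast for the reduction, then return to the original dtype.
attn = F.softmax(scores.float(), dim=-1).type_as(scores)
print(attn.dtype)  # torch.bfloat16
```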
@@ -8,7 +8,7 @@ try:
    import torch_musa
    import torch_musa.core.amp as amp
except ModuleNotFoundError:
    pass
    torch_musa = None

from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
@@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
from math import sqrt

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample
@@ -10,8 +10,9 @@ from einops import rearrange

try:
    import torch_musa
    import torch_musa.core.amp as amp
except ModuleNotFoundError:
    pass
    torch_musa = None

from wan.utils.platform import get_device

@@ -52,17 +53,23 @@ class RMS_norm(nn.Module):
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = sqrt(dim)
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return (
            F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
            * self.scale
            * self.gamma
            + self.bias
        )
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):
@@ -255,10 +262,6 @@ class AttentionBlock(nn.Module):
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=None,
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

@@ -627,7 +630,7 @@ class WanVAE:
    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.bfloat16,
                 dtype=torch.float,
                 device=get_device()):
        self.dtype = dtype
        self.device = device
@@ -654,12 +657,16 @@ class WanVAE:
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        return [
            self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
        ]
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        return [
            self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
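The RMS_norm change is the same precision trick: run `F.normalize` in float32 and cast the result back instead of normalizing directly in a possibly bf16 input dtype. A standalone sketch with toy shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
scale = 4 ** 0.5
gamma = torch.ones(4, 1, 1, dtype=torch.bfloat16)

# channel_first layout: normalize along dim=1 in float32, then cast back.
out = F.normalize(x.float(), dim=1).type_as(x) * scale * gamma
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([2, 4, 8])
```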
@@ -6,7 +6,6 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial

@@ -22,7 +21,7 @@ try:
    from torch_musa.core.memory import empty_cache
    from torch_musa.core.device import synchronize
except ModuleNotFoundError:
    pass
    torch_musa = None

from .distributed.fsdp import shard_model
from .modules.model import WanModel
@@ -181,7 +180,6 @@ class WanT2V:
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        start_time = perf_counter()
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
@@ -193,8 +191,6 @@ class WanT2V:
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]
        end_time = perf_counter()
        logging.info(f"T5 Encoding Context took {end_time - start_time:.2f} seconds.")

        noise = [
            torch.randn(
@@ -243,14 +239,13 @@ class WanT2V:
            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            start_time = perf_counter()
            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
                timestep = [t]

                timestep = torch.stack(timestep)

                self.model.to(self.device)
                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0]
                noise_pred_uncond = self.model(
@@ -266,18 +261,13 @@ class WanT2V:
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]
            end_time = perf_counter()
            logging.info(f"Sampling took {end_time - start_time:.2f} seconds.")

            x0 = latents
            if offload_model:
                self.model.cpu()
                empty_cache()
            if self.rank == 0:
                start_time = perf_counter()
                videos = self.vae.decode(x0)
                end_time = perf_counter()
                logging.info(f"VAE Decoding took {end_time - start_time:.2f} seconds.")

        del noise, latents
        del sample_scheduler
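One hunk above moves `self.model.to(self.device)` from inside the denoising loop to just before it, so the model is placed on the device once rather than re-requested on every step. A trivial runnable illustration of the hoisted form (the tiny `nn.Linear` stands in for the diffusion model):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8)
x = torch.randn(1, 8, device=device)

# Hoisted form: one device placement, then iterate over the timesteps.
model.to(device)
for t in range(10):
    pred = model(x)
print(pred.shape)
```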
@@ -5,7 +5,7 @@ import torch
try:
    import torch_musa
except ModuleNotFoundError:
    pass
    torch_musa = None


def _is_musa():
@@ -31,4 +31,4 @@ def get_torch_distributed_backend() -> str:
    elif _is_musa():
        return "mccl"
    else:
        raise NotImplementedError("No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available")
        raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
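`get_torch_distributed_backend` picks `nccl` for CUDA and `mccl` for MUSA; the hunk above only trims the error message. A self-contained sketch of that selection (the real `_is_musa` body is not shown in the hunk, so the check here is simplified to the import sentinel):

```python
import torch

try:
    import torch_musa  # optional MUSA backend
except ModuleNotFoundError:
    torch_musa = None


def _is_musa() -> bool:
    # Simplified: treat a successful torch_musa import as MUSA availability.
    return torch_musa is not None


def get_torch_distributed_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"
    if _is_musa():
        return "mccl"
    raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
```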
@@ -27,7 +27,7 @@ try:
    from torch_musa.core.memory import empty_cache
    from torch_musa.core.device import synchronize
except ModuleNotFoundError:
    pass
    torch_musa = None

from .modules.vace_model import VaceWanModel
from .text2video import (