Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-12-21 06:32:07 +00:00

Compare commits: 22f4714dc1 ... 43ac073411 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 43ac073411 | |
| | 447aa08620 | |

generate.py | 72 lines changed

(File paths for the diff sections after generate.py are not shown on the extracted page; they are inferred from the classes and imports visible in each hunk.)
generate.py

@@ -26,7 +26,6 @@ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CON
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
 from wan.utils.platform import get_torch_distributed_backend
-from wan.utils.chrono_inspector import ChronoInspector
 
 
 EXAMPLE_PROMPT = {
@@ -366,7 +365,6 @@ def generate(args):
 logging.info(f"Extended prompt: {args.prompt}")
 
 logging.info("Creating WanT2V pipeline.")
-with ChronoInspector("Creating WanT2V pipeline"):
 wan_t2v = wan.WanT2V(
 config=cfg,
 checkpoint_dir=args.ckpt_dir,
@@ -378,21 +376,8 @@ def generate(args):
 t5_cpu=args.t5_cpu,
 )
 
-logging.info("Warming up WanT2V pipeline ...")
-with ChronoInspector("Warming up WanT2V pipeline"):
-_ = wan_t2v.generate(
-args.prompt,
-size=SIZE_CONFIGS[args.size],
-frame_num=args.frame_num,
-shift=args.sample_shift,
-sample_solver=args.sample_solver,
-sampling_steps=3,
-guide_scale=args.sample_guide_scale,
-seed=args.base_seed,
-offload_model=args.offload_model)
-
-logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
-with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
+logging.info(
+f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 video = wan_t2v.generate(
 args.prompt,
 size=SIZE_CONFIGS[args.size],
@@ -437,7 +422,6 @@ def generate(args):
 logging.info(f"Extended prompt: {args.prompt}")
 
 logging.info("Creating WanI2V pipeline.")
-with ChronoInspector("Creating WanI2V pipeline"):
 wan_i2v = wan.WanI2V(
 config=cfg,
 checkpoint_dir=args.ckpt_dir,
@@ -449,22 +433,7 @@ def generate(args):
 t5_cpu=args.t5_cpu,
 )
 
-logging.info("Warming up WanI2V pipeline ...")
-with ChronoInspector("Warming up WanI2V pipeline"):
-_ = wan_i2v.generate(
-args.prompt,
-img,
-max_area=MAX_AREA_CONFIGS[args.size],
-frame_num=args.frame_num,
-shift=args.sample_shift,
-sample_solver=args.sample_solver,
-sampling_steps=3,
-guide_scale=args.sample_guide_scale,
-seed=args.base_seed,
-offload_model=args.offload_model)
-
 logging.info("Generating video ...")
-with ChronoInspector("Generating video"):
 video = wan_i2v.generate(
 args.prompt,
 img,
@@ -511,7 +480,6 @@ def generate(args):
 logging.info(f"Extended prompt: {args.prompt}")
 
 logging.info("Creating WanFLF2V pipeline.")
-with ChronoInspector("Creating WanFLF2V pipeline"):
 wan_flf2v = wan.WanFLF2V(
 config=cfg,
 checkpoint_dir=args.ckpt_dir,
@@ -523,23 +491,7 @@ def generate(args):
 t5_cpu=args.t5_cpu,
 )
 
-logging.info("Warming up WanFLF2V pipeline ...")
-with ChronoInspector("Warming up WanFLF2V pipeline"):
-_ = wan_flf2v.generate(
-args.prompt,
-first_frame,
-last_frame,
-max_area=MAX_AREA_CONFIGS[args.size],
-frame_num=args.frame_num,
-shift=args.sample_shift,
-sample_solver=args.sample_solver,
-sampling_steps=3,
-guide_scale=args.sample_guide_scale,
-seed=args.base_seed,
-offload_model=args.offload_model)
-
 logging.info("Generating video ...")
-with ChronoInspector("Generating video"):
 video = wan_flf2v.generate(
 args.prompt,
 first_frame,
@@ -576,7 +528,6 @@ def generate(args):
 logging.info(f"Extended prompt: {args.prompt}")
 
 logging.info("Creating VACE pipeline.")
-with ChronoInspector("Creating VACE pipeline"):
 wan_vace = wan.WanVace(
 config=cfg,
 checkpoint_dir=args.ckpt_dir,
@@ -594,24 +545,7 @@ def generate(args):
 args.src_ref_images.split(',')
 ], args.frame_num, SIZE_CONFIGS[args.size], device)
 
-logging.info("Warming up VACE pipeline ...")
-with ChronoInspector("Warming up VACE pipeline"):
-video = wan_vace.generate(
-args.prompt,
-src_video,
-src_mask,
-src_ref_images,
-size=SIZE_CONFIGS[args.size],
-frame_num=args.frame_num,
-shift=args.sample_shift,
-sample_solver=args.sample_solver,
-sampling_steps=3,
-guide_scale=args.sample_guide_scale,
-seed=args.base_seed,
-offload_model=args.offload_model)
-
 logging.info(f"Generating video...")
-with ChronoInspector("Generating video"):
 video = wan_vace.generate(
 args.prompt,
 src_video,
@@ -638,7 +572,6 @@ def generate(args):
 
 if "t2i" in args.task:
 logging.info(f"Saving generated image to {args.save_file}")
-with ChronoInspector("Saving generated image"):
 cache_image(
 tensor=video.squeeze(1)[None],
 save_file=args.save_file,
@@ -647,7 +580,6 @@ def generate(args):
 value_range=(-1, 1))
 else:
 logging.info(f"Saving generated video to {args.save_file}")
-with ChronoInspector("Saving generated video"):
 cache_video(
 tensor=video[None],
 save_file=args.save_file,
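The removed ChronoInspector wraps pipeline creation, warm-up, generation, and saving as a timing context manager. Its implementation is not part of this compare; a minimal sketch of what a compatible wall-clock timer could look like (the class body below is an assumption — only the name and the call sites come from the diff):

```python
import logging
import time


class ChronoInspector:
    """Hypothetical minimal timer: logs elapsed wall-clock time for a labeled block."""

    def __init__(self, label):
        self.label = label

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        # log on exit, whether or not the block raised
        logging.info(f"{self.label} took {time.perf_counter() - self.start:.3f}s")


# usage mirroring the removed call sites:
with ChronoInspector("Creating WanT2V pipeline"):
    pass  # pipeline construction would go here
```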
wan/distributed/xdit_context_parallel.py

@@ -9,13 +9,10 @@ from xfuser.core.distributed import (
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
 attn_type:AttnType = AttnType.FA
 
-from wan.modules.rope import rope_apply_pytorch, rope_apply_triton
-
 try:
 import torch_musa
 import torch_musa.core.amp as amp
 attn_type = AttnType.TORCH
-torch.backends.mudnn.allow_tf32 = True
 except ImportError:
 torch_musa = None
 
@@ -35,8 +32,21 @@ def pad_freqs(original_tensor, target_len):
 return padded_tensor
 
 
+def pad_tensor(original_tensor, target_len, pad_value=0.0):
+seq_len, s1, s2 = original_tensor.shape
+pad_size = target_len - seq_len
+padding_tensor = torch.full(
+(pad_size, s1, s2),
+pad_value,
+dtype=original_tensor.dtype,
+device=original_tensor.device,
+)
+padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+return padded_tensor
+
+
 @amp.autocast(enabled=False)
-def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank):
+def rope_apply(x, grid_sizes, freqs):
 """
 x: [B, L, N, C].
 grid_sizes: [B, 3].
@@ -62,6 +72,8 @@ def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank):
 dim=-1).reshape(seq_len, 1, -1)
 
 # apply rotary embedding
+sp_size = get_sequence_parallel_world_size()
+sp_rank = get_sequence_parallel_rank()
 freqs_i = pad_freqs(freqs_i, s * sp_size)
 s_per_rank = s
 freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
@@ -74,6 +86,69 @@ def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank):
 return torch.stack(output).float()
 
 
+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
+"""
+x: [B, L, N, C].
+grid_sizes: [B, 3].
+freqs: [M, C // 2].
+"""
+s, n, c = x.size(1), x.size(2), x.size(3) // 2
+c0 = c - 2 * (c // 3)
+c1 = c // 3
+c2 = c // 3
+
+# split freqs
+freqs_real = freqs[0].split([c0, c1, c2], dim=1)
+freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
+
+# loop over samples
+output = []
+for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+seq_len = f * h * w
+
+# precompute multipliers
+x_i = x[i, :seq_len].reshape(s, n, -1, 2)
+x_real = x_i[..., 0]
+x_imag = x_i[..., 1]
+freqs_real = torch.cat(
+[
+freqs_real[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+freqs_real[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+freqs_real[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+],
+dim=-1,
+).reshape(seq_len, 1, -1)
+freqs_imag = torch.cat(
+[
+freqs_imag[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+freqs_imag[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+freqs_imag[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+],
+dim=-1,
+).reshape(seq_len, 1, -1)
+
+# apply rotary embedding
+sp_size = get_sequence_parallel_world_size()
+sp_rank = get_sequence_parallel_rank()
+
+freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
+freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
+
+freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
+freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
+
+out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
+out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
+
+x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
+x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
+
+# append to collection
+output.append(x_out)
+return torch.stack(output)
+
+
 def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
 # embeddings
 c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
@@ -120,11 +195,15 @@ usp_dit_forward(
 # params
 dtype = self.patch_embedding.weight.dtype
 device = self.patch_embedding.weight.device
+if torch_musa is not None:
 if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
 self.freqs = (
 self.freqs[0].to(dtype=dtype, device=device),
 self.freqs[-1].to(dtype=dtype, device=device)
 )
+else:
+if self.freqs.dtype != dtype or self.freqs.device != device:
+self.freqs = self.freqs.to(dtype=dtype, device=device)
 
 if self.model_type != 'vace' and y is not None:
 x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -142,9 +221,11 @@ usp_dit_forward(
 ])
 
 # time embeddings
+with amp.autocast(dtype=torch.float32):
 e = self.time_embedding(
-sinusoidal_embedding_1d(self.freq_dim, t))
+sinusoidal_embedding_1d(self.freq_dim, t).float())
 e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
 # context
 context_lens = None
@@ -188,7 +269,7 @@ usp_dit_forward(
 
 # unpatchify
 x = self.unpatchify(x, grid_sizes)
-return x
+return [u.float() for u in x]
 
 
 def usp_attn_forward(self,
@@ -211,12 +292,13 @@ def usp_attn_forward(self,
 return q, k, v
 
 q, k, v = qkv_fn(x)
-if torch_musa is None:
-q = rope_apply(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
-k = rope_apply(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
+if torch_musa is not None:
+q = rope_apply_musa(q, grid_sizes, freqs)
+k = rope_apply_musa(k, grid_sizes, freqs)
 else:
-q = rope_apply_pytorch(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
-k = rope_apply_pytorch(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
+q = rope_apply(q, grid_sizes, freqs)
+k = rope_apply(k, grid_sizes, freqs)
 
 # TODO: We should use unpaded q,k,v for attention.
 # k_lens = seq_lens // get_sequence_parallel_world_size()
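Both rope_apply_musa above and the version added to wan/modules/model.py apply rotary embedding with separate real and imaginary tensors instead of complex dtypes: the rotation (a+bi)(c+di) = (ac−bd) + (ad+bc)i is expanded by hand, and pad_tensor pads the real table with 1.0 and the imaginary table with 0.0, so padded positions are rotated by the identity 1+0i. A self-contained sketch of that identity, checked against torch's native complex multiply (the shapes are illustrative assumptions):

```python
import torch

s, n, c = 16, 2, 8                       # sequence, heads, half-channels (illustrative)
x = torch.randn(s, n, c, 2)              # trailing dim holds (real, imag) pairs
freqs = torch.polar(torch.ones(s, 1, c), torch.randn(s, 1, c))  # unit-magnitude rotations

x_real, x_imag = x[..., 0], x[..., 1]
f_real, f_imag = freqs.real, freqs.imag

# (a + bi)(c + di) = (ac - bd) + (ad + bc)i, written with real tensors only
out_real = x_real * f_real - x_imag * f_imag
out_imag = x_real * f_imag + x_imag * f_real
out = torch.stack([out_real, out_imag], dim=-1)

# reference computed with complex tensors
ref = torch.view_as_real(torch.view_as_complex(x) * freqs)
assert torch.allclose(out, ref, atol=1e-6)
```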
wan/first_last_frame2video.py

@@ -22,23 +22,21 @@ try:
 import torch_musa.core.amp as amp
 from torch_musa.core.memory import empty_cache
 from torch_musa.core.device import synchronize
-torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
 torch_musa = None
 
-from wan.distributed.fsdp import shard_model
-from wan.modules.clip import CLIPModel
-from wan.modules.model import WanModel
-from wan.modules.t5 import T5EncoderModel
-from wan.modules.vae import WanVAE
-from wan.utils.fm_solvers import (
+from .distributed.fsdp import shard_model
+from .modules.clip import CLIPModel
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import (
 FlowDPMSolverMultistepScheduler,
 get_sampling_sigmas,
 retrieve_timesteps,
 )
-from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-from wan.utils.platform import get_device
-from wan.utils.memory_format import convert_conv3d_weight_memory_format
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device
 
 
 class WanFLF2V:
@@ -102,7 +100,6 @@ class WanFLF2V:
 self.vae = WanVAE(
 vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 device=self.device)
-convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
 
 self.clip = CLIPModel(
 dtype=config.clip_dtype,
@@ -134,8 +131,7 @@ class WanFLF2V:
 self.sp_size = 1
 
 if dist.is_initialized():
-# dist.barrier()
-pass
+dist.barrier()
 if dit_fsdp:
 self.model = shard_fn(self.model)
 else:
@@ -386,7 +382,6 @@ class WanFLF2V:
 gc.collect()
 synchronize()
 if dist.is_initialized():
-# dist.barrier()
-pass
+dist.barrier()
 
 return videos[0] if self.rank == 0 else None
wan/image2video.py

@@ -22,23 +22,21 @@ try:
 import torch_musa.core.amp as amp
 from torch_musa.core.memory import empty_cache
 from torch_musa.core.device import synchronize
-torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
 torch_musa = None
 
-from wan.distributed.fsdp import shard_model
-from wan.modules.clip import CLIPModel
-from wan.modules.model import WanModel
-from wan.modules.t5 import T5EncoderModel
-from wan.modules.vae import WanVAE
-from wan.utils.fm_solvers import (
+from .distributed.fsdp import shard_model
+from .modules.clip import CLIPModel
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import (
 FlowDPMSolverMultistepScheduler,
 get_sampling_sigmas,
 retrieve_timesteps,
 )
-from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-from wan.utils.platform import get_device
-from wan.utils.memory_format import convert_conv3d_weight_memory_format
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device
 
 
 class WanI2V:
@@ -102,7 +100,6 @@ class WanI2V:
 self.vae = WanVAE(
 vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 device=self.device)
-convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
 
 self.clip = CLIPModel(
 dtype=config.clip_dtype,
@@ -134,8 +131,7 @@ class WanI2V:
 self.sp_size = 1
 
 if dist.is_initialized():
-# dist.barrier()
-pass
+dist.barrier()
 if dit_fsdp:
 self.model = shard_fn(self.model)
 else:
@@ -359,7 +355,6 @@ class WanI2V:
 gc.collect()
 synchronize()
 if dist.is_initialized():
-# dist.barrier()
-pass
+dist.barrier()
 
 return videos[0] if self.rank == 0 else None
wan/modules/attention.py

@@ -19,7 +19,6 @@ try:
 import torch_musa
 FLASH_ATTN_3_AVAILABLE = False
 FLASH_ATTN_2_AVAILABLE = False
-torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
 torch_musa = None
 
@@ -60,7 +59,7 @@ def flash_attention(
 """
 half_dtypes = (torch.float16, torch.bfloat16)
 assert dtype in half_dtypes
-assert q.device.type in ("cuda", "musa") and q.size(-1) <= 256
+assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256
 
 # params
 b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@@ -181,7 +180,6 @@ def attention(
 k = k.transpose(1, 2).to(dtype)
 v = v.transpose(1, 2).to(dtype)
 
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
 out = torch.nn.functional.scaled_dot_product_attention(
 q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
 
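The removed torch.backends.cuda.sdp_kernel context manager is a CUDA-only (and since-deprecated) backend selector; dropping it lets the same scaled_dot_product_attention call dispatch on whatever device the tensors live on, which is presumably the point for the MUSA path. A minimal sketch of the resulting call (sizes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

b, heads, lq, lk, d = 1, 4, 32, 32, 64
q = torch.randn(b, heads, lq, d)
k = torch.randn(b, heads, lk, d)
v = torch.randn(b, heads, lk, d)

# no CUDA-specific kernel pinning; PyTorch picks a backend for the device
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 4, 32, 64])
```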
wan/modules/clip.py

@@ -37,7 +37,7 @@ def pos_interpolate(pos, seq_len):
 return torch.cat([
 pos[:, :n],
 F.interpolate(
-pos[:, n:].reshape(1, src_grid, src_grid, -1).permute(
+pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
 0, 3, 1, 2),
 size=(tar_grid, tar_grid),
 mode='bicubic',
@@ -52,6 +52,12 @@ class QuickGELU(nn.Module):
 return x * torch.sigmoid(1.702 * x)
 
 
+class LayerNorm(nn.LayerNorm):
+
+def forward(self, x):
+return super().forward(x.float()).type_as(x)
+
+
 class SelfAttention(nn.Module):
 
 def __init__(self,
@@ -84,7 +90,10 @@ class SelfAttention(nn.Module):
 
 # compute attention
 p = self.attn_dropout if self.training else 0.0
+if torch_musa is not None:
 x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
+else:
+x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
 x = x.reshape(b, s, c)
 
 # output
@@ -133,10 +142,10 @@ class AttentionBlock(nn.Module):
 self.norm_eps = norm_eps
 
 # layers
-self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
+self.norm1 = LayerNorm(dim, eps=norm_eps)
 self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
 proj_dropout)
-self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
+self.norm2 = LayerNorm(dim, eps=norm_eps)
 if activation == 'swi_glu':
 self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
 else:
@@ -179,7 +188,7 @@ class AttentionPool(nn.Module):
 self.to_q = nn.Linear(dim, dim)
 self.to_kv = nn.Linear(dim, dim * 2)
 self.proj = nn.Linear(dim, dim)
-self.norm = nn.LayerNorm(dim, eps=norm_eps)
+self.norm = LayerNorm(dim, eps=norm_eps)
 self.mlp = nn.Sequential(
 nn.Linear(dim, int(dim * mlp_ratio)),
 QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
@@ -196,6 +205,9 @@ class AttentionPool(nn.Module):
 k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
 
 # compute attention
+if torch_musa is not None:
+x = flash_attention(q, k, v)
+else:
 x = flash_attention(q, k, v, version=2)
 x = x.reshape(b, 1, c)
 
@@ -261,13 +273,13 @@ class VisionTransformer(nn.Module):
 self.dropout = nn.Dropout(embedding_dropout)
 
 # transformer
-self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
+self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
 self.transformer = nn.Sequential(*[
 AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
 activation, attn_dropout, proj_dropout, norm_eps)
 for _ in range(num_layers)
 ])
-self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
+self.post_norm = LayerNorm(dim, eps=norm_eps)
 
 # head
 if pool_type == 'token':
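The LayerNorm subclass introduced above (like WanLayerNorm and T5LayerNorm elsewhere in this compare) follows a single pattern: compute the normalization in float32, then cast the result back to the caller's dtype, so bf16/fp16 activations do not lose precision in the mean/variance reduction. A short sketch of the behavior:

```python
import torch
import torch.nn as nn


class LayerNorm(nn.LayerNorm):
    # normalize in float32, hand the result back in the input dtype
    def forward(self, x):
        return super().forward(x.float()).type_as(x)


x = torch.randn(2, 5, 8, dtype=torch.bfloat16)
y = LayerNorm(8)(x)
assert y.dtype == torch.bfloat16  # reduction ran in fp32, output cast back
```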
wan/modules/model.py

@@ -8,15 +8,13 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 
 from wan.modules.attention import flash_attention
-from wan.modules.rope import rope_apply_pytorch
 
 try:
 import torch_musa
 import torch_musa.core.amp as amp
 from wan.modules.attention import attention as flash_attention
-torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
-pass
+torch_musa = None
 
 __all__ = ['WanModel']
 
@@ -28,7 +26,7 @@ def sinusoidal_embedding_1d(dim, position):
 # preprocess
 assert dim % 2 == 0
 half = dim // 2
-position = position.type(torch.bfloat16)
+position = position.type(torch.float32)
 
 # calculation
 sinusoid = torch.outer(
@@ -37,6 +35,17 @@ def sinusoidal_embedding_1d(dim, position):
 return x
 
 
+@amp.autocast(enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+assert dim % 2 == 0
+freqs = torch.outer(
+torch.arange(max_seq_len),
+1.0 / torch.pow(theta,
+torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+freqs = torch.polar(torch.ones_like(freqs), freqs)
+return freqs
+
+
 @amp.autocast(enabled=False)
 def rope_params_real(
 max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
@@ -98,6 +107,55 @@ def rope_apply(x, grid_sizes, freqs):
 return torch.stack(output).float()
 
 
+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
+n, c = x.size(2), x.size(3) // 2
+c0 = c - 2 * (c // 3)
+c1 = c // 3
+c2 = c // 3
+
+# split freqs
+freqs_real = freqs[0].split([c0, c1, c2], dim=1)
+freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
+
+# loop over samples
+output = []
+for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+seq_len = f * h * w
+
+# precompute multipliers
+x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
+x_real = x_i[..., 0]
+x_imag = x_i[..., 1]
+freqs_real = torch.cat(
+[
+freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+],
+dim=-1,
+).reshape(seq_len, 1, c)
+freqs_imag = torch.cat(
+[
+freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+],
+dim=-1,
+).reshape(seq_len, 1, c)
+
+out_real = x_real * freqs_real - x_imag * freqs_imag
+out_imag = x_real * freqs_imag + x_imag * freqs_real
+
+# apply rotary embedding
+x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
+x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
+
+# append to collection
+output.append(x_out)
+return torch.stack(output)
+
+
 class WanRMSNorm(nn.Module):
 
 def __init__(self, dim, eps=1e-5):
@@ -117,6 +175,19 @@ class WanRMSNorm(nn.Module):
 return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
 
 
+class WanLayerNorm(nn.LayerNorm):
+
+def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+def forward(self, x):
+r"""
+Args:
+x(Tensor): Shape [B, L, C]
+"""
+return super().forward(x.float()).type_as(x)
+
+
 class WanSelfAttention(nn.Module):
 
 def __init__(self,
@@ -160,19 +231,23 @@ class WanSelfAttention(nn.Module):
 return q, k, v
 
 q, k, v = qkv_fn(x)
-if torch_musa is None:
-q = rope_apply(q, grid_sizes, freqs)
-k = rope_apply(k, grid_sizes, freqs)
-else:
-q = rope_apply_pytorch(q, grid_sizes, freqs)
-k = rope_apply_pytorch(k, grid_sizes, freqs)
 
+if torch_musa is not None:
 x = flash_attention(
-q=q,
-k=k,
+q=rope_apply_musa(q, grid_sizes, freqs),
+k=rope_apply_musa(k, grid_sizes, freqs),
 v=v,
 k_lens=seq_lens,
-window_size=self.window_size)
+window_size=self.window_size,
+)
+else:
+x = flash_attention(
+q=rope_apply(q, grid_sizes, freqs),
+k=rope_apply(k, grid_sizes, freqs),
+v=v,
+k_lens=seq_lens,
+window_size=self.window_size,
+)
 
 # output
 x = x.flatten(2)
@@ -277,10 +352,10 @@ class WanAttentionBlock(nn.Module):
 self.eps = eps
 
 # layers
-self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
+self.norm1 = WanLayerNorm(dim, eps)
 self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
 eps)
-self.norm3 = nn.LayerNorm(
+self.norm3 = WanLayerNorm(
 dim, eps,
 elementwise_affine=True) if cross_attn_norm else nn.Identity()
 self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
@@ -288,7 +363,7 @@ class WanAttentionBlock(nn.Module):
 (-1, -1),
 qk_norm,
 eps)
-self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False)
+self.norm2 = WanLayerNorm(dim, eps)
 self.ffn = nn.Sequential(
 nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
 nn.Linear(ffn_dim, dim))
@@ -314,18 +389,23 @@ class WanAttentionBlock(nn.Module):
 grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
 freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
 """
+assert e.dtype == torch.float32
+with amp.autocast(dtype=torch.float32):
 e = (self.modulation + e).chunk(6, dim=1)
+assert e[0].dtype == torch.float32
 
 # self-attention
 y = self.self_attn(
-self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
 freqs)
+with amp.autocast(dtype=torch.float32):
 x = x + y * e[2]
 
 # cross-attention & ffn function
 def cross_attn_ffn(x, context, context_lens, e):
 x = x + self.cross_attn(self.norm3(x), context, context_lens)
-y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
+y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+with amp.autocast(dtype=torch.float32):
 x = x + y * e[5]
 return x
 
@@ -344,7 +424,7 @@ class Head(nn.Module):
 
 # layers
 out_dim = math.prod(patch_size) * out_dim
-self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False)
+self.norm = WanLayerNorm(dim, eps)
 self.head = nn.Linear(dim, out_dim)
 
 # modulation
@@ -356,8 +436,10 @@ class Head(nn.Module):
 x(Tensor): Shape [B, L1, C]
 e(Tensor): Shape [B, C]
 """
+assert e.dtype == torch.float32
+with amp.autocast(dtype=torch.float32):
 e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
-x = self.head(self.norm(x) * (1 + e[1]) + e[0])
+x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
 return x
 
 
@@ -491,16 +573,7 @@ class WanModel(ModelMixin, ConfigMixin):
 # buffers (don't use register_buffer otherwise dtype will be changed in to())
 assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 d = dim // num_heads
-if torch_musa is None:
-self.freqs = torch.cat(
-[
-rope_params(1024, d - 4 * (d // 6)),
-rope_params(1024, 2 * (d // 6)),
-rope_params(1024, 2 * (d // 6)),
-],
-dim=1,
-)
-else:
+if torch_musa is not None:
 freqs_real = torch.cat(
 [
 rope_params_real(1024, d - 4 * (d // 6)),
@@ -518,6 +591,15 @@ class WanModel(ModelMixin, ConfigMixin):
 dim=1,
 )
 self.freqs = (freqs_real, freqs_imag)
+else:
+self.freqs = torch.cat(
+[
+rope_params(1024, d - 4 * (d // 6)),
+rope_params(1024, 2 * (d // 6)),
+rope_params(1024, 2 * (d // 6)),
+],
+dim=1,
+)
 
 if model_type == 'i2v' or model_type == 'flf2v':
 self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@@ -560,15 +642,15 @@ class WanModel(ModelMixin, ConfigMixin):
 # params
 dtype = self.patch_embedding.weight.dtype
 device = self.patch_embedding.weight.device
-if torch_musa is None:
-if self.freqs.dtype != dtype or self.freqs.device != device:
-self.freqs = self.freqs.to(dtype=dtype, device=device)
-else:
+if torch_musa is not None:
 if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
 self.freqs = (
 self.freqs[0].to(dtype=dtype, device=device),
-self.freqs[-1].to(dtype=dtype, device=device),
+self.freqs[-1].to(dtype=dtype, device=device)
 )
+else:
+if self.freqs.dtype != dtype or self.freqs.device != device:
+self.freqs = self.freqs.to(dtype=dtype, device=device)
 
 if y is not None:
 x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -586,9 +668,11 @@ class WanModel(ModelMixin, ConfigMixin):
 ])
 
 # time embeddings
+with amp.autocast(dtype=torch.float32):
 e = self.time_embedding(
-sinusoidal_embedding_1d(self.freq_dim, t))
+sinusoidal_embedding_1d(self.freq_dim, t).float())
 e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
 # context
 context_lens = None
@@ -620,7 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
 # unpatchify
 x = self.unpatchify(x, grid_sizes)
-return x
+return [u.float() for u in x]
 
 def unpatchify(self, x, grid_sizes):
 r"""
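The restored rope_params builds the rotation table as a complex tensor via torch.polar, while the retained rope_params_real path keeps separate real/imaginary tables for backends without complex support; the real and imaginary parts of the former are exactly the cos/sin tables of the latter. A quick sketch of that correspondence (sizes are illustrative assumptions):

```python
import torch

max_seq_len, dim, theta = 32, 16, 10000
freqs = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
angles = torch.outer(torch.arange(max_seq_len), freqs)

table = torch.polar(torch.ones_like(angles), angles)   # what rope_params returns
assert torch.allclose(table.real, torch.cos(angles))   # real part == cos table
assert torch.allclose(table.imag, torch.sin(angles))   # imag part == sin table
```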
wan/modules/rope.py (deleted)

@@ -1,317 +0,0 @@
-from typing import Optional, Tuple
-
-import triton
-import triton.language as tl
-import torch
-
-
-def pad_tensor(
-original_tensor: torch.tensor, target_len: int, pad_value: float = 0.0
-) -> torch.tensor:
-seq_len, s1, s2 = original_tensor.shape
-pad_size = target_len - seq_len
-padding_tensor = torch.full(
-(pad_size, s1, s2),
-pad_value,
-dtype=original_tensor.dtype,
-device=original_tensor.device,
-)
-padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
-return padded_tensor
-
-
-def rope_apply_pytorch(
-x: torch.tensor,
-grid_sizes: torch.tensor,
-freqs: Tuple[torch.tensor],
-sp_size: Optional[int] = None,
-sp_rank: Optional[int] = None,
-) -> torch.tensor:
-"""
-x: [B, L, N, C].
-grid_sizes: [B, 3].
-freqs: [M, C // 2].
-"""
-s, n, c = x.size(1), x.size(2), x.size(3) // 2
-c0 = c - 2 * (c // 3)
-c1 = c // 3
-c2 = c // 3
-
-# split freqs
-freqs_real = freqs[0].split([c0, c1, c2], dim=1)
-freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
-
-# loop over samples
-output = []
-for i, (f, h, w) in enumerate(grid_sizes.tolist()):
-seq_len = f * h * w
-
-# precompute multipliers
-x_i = x[i, :seq_len].reshape(s, n, -1, 2)
-x_real = x_i[..., 0]
-x_imag = x_i[..., 1]
-freqs_real = torch.cat(
-[
-freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
-freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
-freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
-],
-dim=-1,
-).reshape(seq_len, 1, -1)
-freqs_imag = torch.cat(
-[
-freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
-freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
-freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
-],
-dim=-1,
-).reshape(seq_len, 1, -1)
-
-if sp_rank is None:
-freqs_real_rank = freqs_real
-freqs_imag_rank = freqs_imag
-else:
-freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
-freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
-freqs_real_rank = freqs_real[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
-freqs_imag_rank = freqs_imag[(sp_rank * s) : ((sp_rank + 1) * s), :, :]
-
-out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
-out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
-
-x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
-x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
-
-# append to collection
-output.append(x_out)
-
-return torch.stack(output)
-
-
-@triton.jit
-def rope_kernel(
-x_ptr, # [B, S, N, 2C]
-grid_sizes_ptr, # [B, 3]
-freqs_real_ptr, # [M, C]
-freqs_imag_ptr, # [M, C]
-output_ptr, # [B, S, N, 2C]
-sp_size, # SP world size
-sp_rank, # SP rank
-B,
-S,
-N: tl.constexpr,
-C: tl.constexpr,
-M: tl.constexpr,
-CfM: tl.constexpr,
-ChM: tl.constexpr,
-CwM: tl.constexpr,
-SEQ_BLOCK: tl.constexpr,
-HEADS_BLOCK: tl.constexpr,
-):
-Cf = C - 2 * (C // 3)
-Ch = C // 3
-Cw = C // 3
-
-batch_idx = tl.program_id(0)
-seqlen_group_idx = tl.program_id(1)
-head_group_idx = tl.program_id(2)
-
-base = batch_idx * 3
-F = tl.load(grid_sizes_ptr + base + 0)
-H = tl.load(grid_sizes_ptr + base + 1)
-W = tl.load(grid_sizes_ptr + base + 2)
-seq_len = F * H * W
-
-global_offset = sp_rank * S + seqlen_group_idx * SEQ_BLOCK
-seq_indices = global_offset + tl.arange(0, SEQ_BLOCK)
-
-limit = tl.minimum(seq_len, S * sp_size)
-seq_mask = seq_indices < limit
-seq_indices = tl.where(seq_mask, seq_indices, 0)
-
-HW = H * W
-f_idx = seq_indices // HW
-rem = seq_indices - f_idx * HW
-h_idx = rem // W
-w_idx = rem - h_idx * W
-
-freq_offset_cf = tl.arange(0, CfM) # column offsets for segment 1, [0, Cf)
-freq_offset_ch = Cf + tl.arange(0, ChM) # column offsets for segment 2, [Cf, Cf+Ch)
-freq_offset_cw = Cf + Ch + tl.arange(0, CwM) # column offsets for segment 3, [Cf+Ch, C)
-# gather the frequency value for each sequence position (broadcasting fills in the per-position rows)
-# address into the frequency table = idx * C + col_offset
-freq_addr_cf = f_idx[:, None] * C + freq_offset_cf[None, :]
-freq_addr_ch = h_idx[:, None] * C + freq_offset_ch[None, :]
-freq_addr_cw = w_idx[:, None] * C + freq_offset_cw[None, :]
-
-freqs_real_cf = tl.load(
-freqs_real_ptr + freq_addr_cf,
-mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
-other=1.0,
-).to(tl.float32)
-freqs_imag_cf = tl.load(
-freqs_imag_ptr + freq_addr_cf,
-mask=(seq_mask[:, None] & (freq_offset_cf[None, :] < Cf)),
-other=1.0,
-).to(tl.float32)
-freqs_real_ch = tl.load(
-freqs_real_ptr + freq_addr_ch,
-mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
-other=1.0,
-).to(tl.float32)
-freqs_imag_ch = tl.load(
-freqs_imag_ptr + freq_addr_ch,
-mask=(seq_mask[:, None] & (freq_offset_ch[None, :] < Cf + Ch)),
-other=1.0,
-).to(tl.float32)
-freqs_real_cw = tl.load(
-freqs_real_ptr + freq_addr_cw,
-mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
-other=1.0,
-).to(tl.float32)
-freqs_imag_cw = tl.load(
-freqs_imag_ptr + freq_addr_cw,
-mask=(seq_mask[:, None] & (freq_offset_cw[None, :] < C)),
-other=1.0,
-).to(tl.float32)
-
-# expand the frequency values for the multiply with x (broadcast over the head dimension)
-freqs_real_cf = freqs_real_cf[:, None, :] # [SEQ_BLOCK, 1, Cf]
-freqs_imag_cf = freqs_imag_cf[:, None, :]
-freqs_real_ch = freqs_real_ch[:, None, :]
-freqs_imag_ch = freqs_imag_ch[:, None, :]
-freqs_real_cw = freqs_real_cw[:, None, :]
-freqs_imag_cw = freqs_imag_cw[:, None, :]
-
-# load the real and imaginary parts of the matching x block (shape: [SEQ_BLOCK, HEADS_BLOCK, C])
-seq_offset = seqlen_group_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
-head_offset = head_group_idx * HEADS_BLOCK + tl.arange(0, HEADS_BLOCK)
-# compute the offset into x_ptr
-base_offset = batch_idx * S * N * 2 * C
-seq_head_offset = (
-base_offset
-+ seq_offset[:, None, None] * (N * 2 * C)
-+ head_offset[None, :, None] * (2 * C)
-)
-x_mask = (seq_offset < S)[:, None, None] & (head_offset < N)[None, :, None]
-
-# load the channel segments of x; positions past the actual length are masked to 0
-# segment 1: channels [0, Cf-1]
-chan_cf = tl.arange(0, CfM * 2)
-mask_2cf_chan = chan_cf < Cf * 2
-x_cf = tl.load(
-x_ptr + seq_head_offset + chan_cf[None, None, :],
-mask=(x_mask & mask_2cf_chan[None, None, :]),
-other=0.0,
-).to(tl.float32)
-x_cf = x_cf.reshape(
-SEQ_BLOCK, HEADS_BLOCK, CfM, 2
-) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
-x_real_cf, x_imag_cf = x_cf.split()
-
-# apply the RoPE rotation (segment 1)
-out_real_cf = x_real_cf * freqs_real_cf - x_imag_cf * freqs_imag_cf
-out_imag_cf = x_real_cf * freqs_imag_cf + x_imag_cf * freqs_real_cf
-
-out_cf = tl.interleave(out_real_cf, out_imag_cf) # [SEQ_BLOCK, HEADS_BLOCK, CfM, 2]
-tl.store(
-output_ptr + seq_head_offset + chan_cf[None, None, :],
-out_cf,
-mask=(x_mask & mask_2cf_chan[None, None, :]),
-)
-
-# segment 2: channels [Cf, Cf+Ch-1]
-chan_ch = tl.arange(0, ChM * 2) + Cf * 2
-mask_2ch_chan = chan_ch < 2 * (Cf + Ch)
-x_ch = tl.load(
-x_ptr + seq_head_offset + chan_ch[None, None, :],
-mask=(x_mask & mask_2ch_chan[None, None, :]),
-other=0.0,
-).to(tl.float32)
-x_ch = x_ch.reshape(SEQ_BLOCK, HEADS_BLOCK, ChM, 2)
-x_real_ch, x_imag_ch = x_ch.split()
-out_real_ch = x_real_ch * freqs_real_ch - x_imag_ch * freqs_imag_ch
-out_imag_ch = x_real_ch * freqs_imag_ch + x_imag_ch * freqs_real_ch
-
-out_ch = tl.interleave(out_real_ch, out_imag_ch) # [SEQ_BLOCK, HEADS_BLOCK, ChM, 2]
-tl.store(
-output_ptr + seq_head_offset + chan_ch[None, None, :],
-out_ch,
-mask=(x_mask & mask_2ch_chan[None, None, :]),
-)
-
-# segment 3: channels [Cf+Ch, C-1]
-chan_cw = tl.arange(0, CwM * 2) + (Cf + Ch) * 2
-mask_2cw_chan = chan_cw < 2 * C
-x_cw = tl.load(
-x_ptr + seq_head_offset + chan_cw[None, None, :],
-mask=(x_mask & mask_2cw_chan[None, None, :]),
-other=0.0,
-).to(tl.float32)
-x_cw = x_cw.reshape(SEQ_BLOCK, HEADS_BLOCK, CwM, 2)
-x_real_cw, x_imag_cw = x_cw.split()
-out_real_cw = x_real_cw * freqs_real_cw - x_imag_cw * freqs_imag_cw
-out_imag_cw = x_real_cw * freqs_imag_cw + x_imag_cw * freqs_real_cw
-
-out_cw = tl.interleave(out_real_cw, out_imag_cw)
-tl.store(
-output_ptr + seq_head_offset + chan_cw[None, None, :],
-out_cw,
-mask=(x_mask & mask_2cw_chan[None, None, :]),
-)
-
-
-@torch._dynamo.disable
-def rope_apply_triton(
-x: torch.tensor,
-grid_sizes: torch.tensor,
-freqs: Tuple[torch.tensor],
-sp_size: Optional[int] = None,
-sp_rank: Optional[int] = None,
-) -> torch.tensor:
-"""
-x: [1, 9450, 40, 128]
-grid_sizes: [[21, 45, 80]]
-freqs_real: [1024, 64]
-freqs_imag: [1024, 64]
-"""
-B, S, N, C = x.shape
-C = C // 2
-Cf = C - 2 * (C // 3) # frequency length for the first dimension
-Ch = C // 3 # frequency length for the second dimension
-Cw = C // 3 # frequency length for the third dimension
-M = freqs[0].shape[0]
-
-SEQ_BLOCK = 64 # sequence length processed per thread block
-HEADS_BLOCK = 8 # number of heads processed per thread block
-
-if sp_rank is None:
-sp_size = 1
-sp_rank = 0
-
-grid_sizes = grid_sizes.to(device=x.device)
-output = torch.empty_like(x)
-
-rope_kernel[(B, triton.cdiv(S, SEQ_BLOCK), triton.cdiv(N, HEADS_BLOCK))](
-x,
-grid_sizes,
-freqs[0],
-freqs[-1],
-output,
-sp_size,
-sp_rank,
-B,
-S,
-N=N,
-C=C,
-M=M,
-CfM=triton.next_power_of_2(Cf),
-ChM=triton.next_power_of_2(Ch),
-CwM=triton.next_power_of_2(Cw),
-SEQ_BLOCK=SEQ_BLOCK,
-HEADS_BLOCK=HEADS_BLOCK,
-num_warps=32,
-num_stages=3,
-)
-
-return output.float()
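For reference, the deleted Triton kernel was launched over a 3-D grid of (batch, sequence blocks, head blocks), exactly as in the `rope_kernel[...]` call above. With the example shapes from its own docstring, x: [1, 9450, 40, 128], the launch arithmetic works out as:

```python
import math

B, S, N = 1, 9450, 40                  # from the deleted docstring's example shapes
SEQ_BLOCK, HEADS_BLOCK = 64, 8         # block sizes hard-coded in rope_apply_triton

grid = (B, math.ceil(S / SEQ_BLOCK), math.ceil(N / HEADS_BLOCK))
print(grid)  # (1, 148, 5) -> 740 kernel instances per launch
```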
wan/modules/t5.py

@@ -6,14 +6,15 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.cuda import current_device
 
 try:
 import torch_musa
+from torch_musa.core.device import current_device
 except ModuleNotFoundError:
 torch_musa = None
 
-from wan.modules.tokenizers import HuggingfaceTokenizer
-from wan.utils.platform import get_device
+from .tokenizers import HuggingfaceTokenizer
 
 __all__ = [
 'T5Model',
@@ -65,8 +66,10 @@ class T5LayerNorm(nn.Module):
 self.weight = nn.Parameter(torch.ones(dim))
 
 def forward(self, x):
-x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) +
+x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
 self.eps)
+if self.weight.dtype in [torch.float16, torch.bfloat16]:
+x = x.type_as(self.weight)
 return self.weight * x
 
 
@@ -114,7 +117,7 @@ class T5Attention(nn.Module):
 
 # compute attention (T5 does not use scaling)
 attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
-attn = F.softmax(attn, dim=-1)
+attn = F.softmax(attn.float(), dim=-1).type_as(attn)
 x = torch.einsum('bnij,bjnc->binc', attn, v)
 
 # output
@@ -259,7 +262,7 @@ class T5RelativeEmbedding(nn.Module):
 
 # embeddings for small and large positions
 max_exact = num_buckets // 2
-rel_pos_large = max_exact + (torch.log(rel_pos / max_exact) /
+rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
 math.log(self.max_dist / max_exact) *
 (num_buckets - max_exact)).long()
 rel_pos_large = torch.min(
@@ -479,7 +482,7 @@ class T5EncoderModel:
 self,
 text_len,
 dtype=torch.bfloat16,
-device=get_device(),
+device=current_device(),
 checkpoint_path=None,
 tokenizer_path=None,
 shard_fn=None,
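The import change at the top of this file swaps the default device provider at import time: torch.cuda.current_device on stock PyTorch, shadowed by torch_musa.core.device.current_device when torch_musa is importable. The same shadowing pattern in isolation:

```python
from torch.cuda import current_device

try:
    import torch_musa
    # on MUSA builds, shadow the CUDA helper with the MUSA equivalent
    from torch_musa.core.device import current_device
except ModuleNotFoundError:
    torch_musa = None

# callers stay backend-agnostic, e.g. device=current_device() as in T5EncoderModel
```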
wan/modules/vae.py
@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import logging
-from math import sqrt
 
 import torch
+import torch.cuda.amp as amp
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Upsample
@@ -10,11 +10,11 @@ from einops import rearrange
 
 try:
     import torch_musa
-    torch.backends.mudnn.allow_tf32 = True
+    import torch_musa.core.amp as amp
 except ModuleNotFoundError:
     torch_musa = None
 
-from wan.utils.platform import get_device_type
+from wan.utils.platform import get_device
 
 __all__ = [
     'WanVAE',
@@ -53,17 +53,23 @@ class RMS_norm(nn.Module):
         shape = (dim, *broadcastable_dims) if channel_first else (dim,)
 
         self.channel_first = channel_first
-        self.scale = sqrt(dim)
+        self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(shape))
         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
 
     def forward(self, x):
-        return (
-            F.normalize(x.float(), dim=(1 if self.channel_first else -1)).type_as(x)
-            * self.scale
-            * self.gamma
-            + self.bias
-        )
+        return F.normalize(
+            x, dim=(1 if self.channel_first else
+                    -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+    def forward(self, x):
+        """
+        Fix bfloat16 support for nearest neighbor interpolation.
+        """
+        return super().forward(x.float()).type_as(x)
 
 
 class Resample(nn.Module):
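The restored Upsample subclass works around nearest-neighbor interpolation kernels that, on some builds and backends, do not accept bfloat16: it round-trips through fp32 and casts back, staying drop-in compatible with nn.Upsample. A sketch of the same wrapper:

import torch
import torch.nn as nn


class UpsampleFP32(nn.Upsample):
    """Nearest-neighbor upsample that computes in fp32 and casts back."""

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


up = UpsampleFP32(scale_factor=2.0, mode='nearest')
print(up(torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)).shape)  # (1, 3, 16, 16)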
@@ -256,10 +262,6 @@ class AttentionBlock(nn.Module):
                 q,
                 k,
                 v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False,
-                scale=None,
             )
             x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
 
@@ -628,8 +630,8 @@ class WanVAE:
     def __init__(self,
                  z_dim=16,
                  vae_pth='cache/vae_step_411000.pth',
-                 dtype=torch.bfloat16,
-                 device=get_device_type()):
+                 dtype=torch.float,
+                 device=get_device()):
         self.dtype = dtype
         self.device = device
 
@@ -655,12 +657,16 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
+        with amp.autocast(dtype=self.dtype):
             return [
-                self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) for u in videos
+                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
+                for u in videos
             ]
 
     def decode(self, zs):
+        with amp.autocast(dtype=self.dtype):
             return [
-                self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
+                self.model.decode(u.unsqueeze(0),
+                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                 for u in zs
             ]
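Restoring amp.autocast makes the VAE forward run in self.dtype inside the context while callers still receive fp32 tensors via the explicit .float(). A minimal sketch of that contract (requires a CUDA device; the Conv3d stands in for the real VAE):

import torch
import torch.cuda.amp as amp
import torch.nn as nn

model = nn.Conv3d(3, 8, 1).cuda()
x = torch.randn(1, 3, 4, 16, 16, device='cuda')

with amp.autocast(dtype=torch.bfloat16):
    z = model(x)      # conv runs in bf16 under autocast
out = z.float()       # callers get fp32, mirroring encode()/decode()
print(z.dtype, out.dtype)  # torch.bfloat16 torch.float32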
wan/text2video.py
@@ -20,23 +20,20 @@ try:
     import torch_musa.core.amp as amp
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
-    torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
     torch_musa = None
 
-from wan.distributed.fsdp import shard_model
-from wan.modules.model import WanModel
-from wan.modules.t5 import T5EncoderModel
-from wan.modules.vae import WanVAE
-from wan.utils.fm_solvers import (
+from .distributed.fsdp import shard_model
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import (
     FlowDPMSolverMultistepScheduler,
     get_sampling_sigmas,
     retrieve_timesteps,
 )
-from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-from wan.utils.platform import get_device
-from wan.utils.memory_format import convert_conv3d_weight_memory_format
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device
 
 
 class WanT2V:
 
@@ -94,7 +91,6 @@ class WanT2V:
         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
             device=self.device)
-        convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
         self.model = WanModel.from_pretrained(checkpoint_dir)
@@ -116,8 +112,7 @@ class WanT2V:
         self.sp_size = 1
 
         if dist.is_initialized():
-            # dist.barrier()
-            pass
+            dist.barrier()
         if dit_fsdp:
             self.model = shard_fn(self.model)
         else:
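dist.barrier() blocks every rank until all ranks reach the call, so the sharding below starts in lockstep across the process group. A runnable sketch under torchrun (the backend choice is an assumption; use 'nccl' on GPU nodes):

import torch.distributed as dist

# launch with: torchrun --nproc_per_node=2 barrier_sketch.py
dist.init_process_group(backend='gloo')
# ... per-rank setup (checkpoint loading, sharding) happens here ...
dist.barrier()  # no rank proceeds until every rank has arrived
dist.destroy_process_group()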
@@ -244,13 +239,13 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
 
-            self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]
 
                 timestep = torch.stack(timestep)
 
+                self.model.to(self.device)
                 noise_pred_cond = self.model(
                     latent_model_input, t=timestep, **arg_c)[0]
                 noise_pred_uncond = self.model(
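Moving self.model.to(self.device) inside the loop presumably pairs with offload_model-style flows: if the weights were pushed off the device after a previous step they come back before the forward pass, and the call is a cheap no-op when they are already resident. A sketch of that pattern:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 4)

for step in range(3):
    model.to(device)  # no-op when the weights are already on `device`
    y = model(torch.randn(1, 4, device=device))
    # an offloading flow would move the model back to CPU here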
@@ -280,7 +275,6 @@ class WanT2V:
         gc.collect()
         synchronize()
         if dist.is_initialized():
-            # dist.barrier()
-            pass
+            dist.barrier()
 
         return videos[0] if self.rank == 0 else None
wan/utils/__init__.py
@@ -5,13 +5,10 @@ from .fm_solvers import (
 )
 from .fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .vace_processor import VaceVideoProcessor
-from .platform import get_device, get_device_type, get_torch_distributed_backend
-from .memory_format import convert_conv3d_weight_memory_format
-from .chrono_inspector import ChronoInspector
+from .platform import get_device, get_torch_distributed_backend
 
 __all__ = [
     'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
     'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
-    'VaceVideoProcessor', 'get_device', 'get_device_type', 'get_torch_distributed_backend',
-    'convert_conv3d_weight_memory_format', 'ChronoInspector'
+    'VaceVideoProcessor', 'get_device', 'get_torch_distributed_backend'
 ]
wan/utils/chrono_inspector.py (deleted)
@@ -1,15 +0,0 @@
-from time import perf_counter
-from logging import info
-
-
-class ChronoInspector(object):
-    def __init__(self, name:str="Block"):
-        self.name = name
-
-    def __enter__(self):
-        self.start_time:float = perf_counter()
-        return self  # optional: return self so callers can inspect the timer
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        end_time:float = perf_counter()
-        info(f"[{self.name}] Elapsed time: {end_time - self.start_time:.2f} seconds")
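For reference, the deleted helper was a plain timing context manager; used standalone it needs logging configured at INFO, or the message is swallowed. An equivalent self-contained sketch:

import logging
from time import perf_counter, sleep

logging.basicConfig(level=logging.INFO)


class ChronoInspector:
    """Log the wall-clock time spent inside a `with` block."""

    def __init__(self, name: str = "Block"):
        self.name = name

    def __enter__(self):
        self.start_time = perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = perf_counter() - self.start_time
        logging.info(f"[{self.name}] Elapsed time: {elapsed:.2f} seconds")


with ChronoInspector("sleep"):
    sleep(0.1)  # on exit logs: [sleep] Elapsed time: 0.10 seconds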
wan/utils/memory_format.py (deleted)
@@ -1,76 +0,0 @@
-import torch
-
-
-def convert_conv3d_weight_memory_format(module:torch.nn.Module, memory_format:torch.memory_format):
-    r"""Convert the ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``.
-
-    The conversion recursively applies to nested ``nn.Module``, including ``module``.
-    Note that it only changes the memory_format, not the semantics of each dimension.
-    This function is used to facilitate the computation to adopt NHWC kernels, which
-    provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
-
-    .. note::
-        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
-        than the utility function ``convert_conv3d_weight_memory_format``. Any
-        layer with 4d weight will be affected by ``model.to``, which does not
-        necessarily benefit from conversion to the specified ``memory_format``.
-        One case we are confident about is the NDHWC (channels_last_3d) conversion
-        for convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
-        even in cases where we have to apply permutation to input tensors.
-
-        Hence our strategy here is to convert only the weight of convolution to
-        channels_last_3d. This ensures that:
-        1. Fast convolution kernels will be used, the benefit of which could
-           outweigh the overhead of permutation (if the input is not in the same format).
-        2. No unnecessary permutations are applied on layers that do not benefit
-           from memory_format conversion.
-
-        The optimal case is that layers between convolution layers are channels
-        last compatible. The input tensor would be permuted to channels last when it
-        encounters the first convolution layer and stay in that memory format.
-        Hence following convolutions will not need to permute their input tensor.
-
-        In case a channels-last-incompatible layer sits between convolution
-        layers, we need to permute the input tensor back to contiguous format
-        for that layer. The input tensor will go through the remaining layers in
-        contiguous format and be permuted to channels last when it encounters
-        another convolution layer. There's no point in propagating that
-        permutation to an earlier layer, as most layers are quite agnostic to
-        ``memory_format``.
-
-        This claim might change when PyTorch supports fusion of permutation, as
-        there might have been a better spot to fuse the permutation other than
-        immediately before a convolution.
-
-    Args:
-        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
-            ``nn.Module``
-        memory_format: user specified ``memory_format``,
-            e.g. ``torch.channels_last`` or ``torch.contiguous_format``
-
-    Returns:
-        The original module with updated ``nn.Conv3d``
-
-    Example:
-        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
-        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
-        >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
-        >>> model = nn.Sequential(
-        >>>     nn.Conv3d(8, 4, 3)).cuda().half()
-        >>> # This is identical to:
-        >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
-        >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
-        >>> out = model(input)
-    """
-
-    # TODO: expand this to `_ConvNd` when channels_last support is extended
-    # beyond only 4d tensors.
-    if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
-        weight_data = (
-            module.weight.detach().clone().contiguous(memory_format=memory_format)
-        )
-        module.weight.data = weight_data.resize_(
-            weight_data.size(), memory_format=memory_format
-        )
-    for child in module.children():
-        convert_conv3d_weight_memory_format(child, memory_format)
-    return module
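The deleted utility rewrites only the convolution weights in place; the effect can be verified with the public memory-format predicates. A CPU-only sketch of the core transformation:

import torch
import torch.nn as nn

conv = nn.Conv3d(8, 4, 3)
w = conv.weight.detach().clone().contiguous(memory_format=torch.channels_last_3d)
conv.weight.data = w.resize_(w.size(), memory_format=torch.channels_last_3d)

print(conv.weight.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(conv.weight.is_contiguous())                                      # False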
wan/utils/platform.py
@@ -9,38 +9,22 @@ except ModuleNotFoundError:
 
 
 def _is_musa():
-    if torch_musa is None:
-        return False
-    else:
-        return True
+    try:
+        if torch.musa.is_available():
+            return True
+    except ModuleNotFoundError:
+        return False
 
 
 def get_device(local_rank:Optional[int]=None) -> torch.device:
     if torch.cuda.is_available():
-        return (
-            torch.cuda.current_device()
-            if local_rank is None
-            else torch.device("cuda", local_rank)
-        )
+        return torch.cuda.current_device() if local_rank is None else torch.device("cuda", local_rank)
     elif _is_musa():
-        return (
-            torch.musa.current_device()
-            if local_rank is None
-            else torch.device("musa", local_rank)
-        )
+        return torch.musa.current_device() if local_rank is None else torch.device("musa", local_rank)
     else:
         return torch.device("cpu")
 
 
-def get_device_type() -> str:
-    if torch.cuda.is_available():
-        return "cuda"
-    elif _is_musa():
-        return "musa"
-    else:
-        return "cpu"
-
-
 def get_torch_distributed_backend() -> str:
     if torch.cuda.is_available():
         return "nccl"
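Note that torch.cuda.current_device() returns an int device index rather than a torch.device, so callers of get_device() must accept both forms. A trimmed sketch of the dispatch (MUSA branch omitted; illustrative only):

import torch


def get_device(local_rank=None):
    if torch.cuda.is_available():
        return (torch.cuda.current_device()  # int index when no rank is given
                if local_rank is None
                else torch.device("cuda", local_rank))
    return torch.device("cpu")


print(get_device())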
22
wan/vace.py
@@ -26,12 +26,11 @@ try:
     import torch_musa.core.amp as amp
     from torch_musa.core.memory import empty_cache
     from torch_musa.core.device import synchronize
-    torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
     torch_musa = None
 
-from wan.modules.vace_model import VaceWanModel
-from wan.text2video import (
+from .modules.vace_model import VaceWanModel
+from .text2video import (
     FlowDPMSolverMultistepScheduler,
     FlowUniPCMultistepScheduler,
     T5EncoderModel,
@@ -41,9 +40,8 @@ from wan.text2video import (
     retrieve_timesteps,
     shard_model,
 )
-from wan.utils.vace_processor import VaceVideoProcessor
-from wan.utils.platform import get_device, get_torch_distributed_backend
-from wan.utils.memory_format import convert_conv3d_weight_memory_format
+from .utils.vace_processor import VaceVideoProcessor
+from .utils.platform import get_device, get_torch_distributed_backend
 
 
 class WanVace(WanT2V):
@@ -102,7 +100,6 @@ class WanVace(WanT2V):
         self.vae = WanVAE(
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
             device=self.device)
-        convert_conv3d_weight_memory_format(self.vae.model, memory_format=torch.channels_last_3d)
 
         logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
         self.model = VaceWanModel.from_pretrained(checkpoint_dir)
@@ -130,8 +127,7 @@ class WanVace(WanT2V):
         self.sp_size = 1
 
         if dist.is_initialized():
-            # dist.barrier()
-            pass
+            dist.barrier()
         if dit_fsdp:
             self.model = shard_fn(self.model)
         else:
@@ -484,8 +480,7 @@ class WanVace(WanT2V):
         gc.collect()
         synchronize()
         if dist.is_initialized():
-            # dist.barrier()
-            pass
+            dist.barrier()
 
         return videos[0] if self.rank == 0 else None
 
@@ -644,7 +639,7 @@ class WanVaceMP(WanVace):
         else:
             sp_size = 1
 
-        # dist.barrier()
+        dist.barrier()
         model = shard_fn(model)
         sample_neg_prompt = self.config.sample_neg_prompt
 
@@ -775,8 +770,7 @@ class WanVaceMP(WanVace):
         gc.collect()
         synchronize()
         if dist.is_initialized():
-            # dist.barrier()
-            pass
+            dist.barrier()
 
         if rank == 0:
             out_q.put(videos[0].cpu())