Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-12-20 14:12:04 +00:00
Compare commits: 618d94c564 ... 43ac073411 (2 commits)
Commits in this range:
- 43ac073411
- 447aa08620
generate.py (12 changed lines)

@@ -12,12 +12,20 @@ import random
 import torch
 import torch.distributed as dist
+from torch.cuda import set_device
 from PIL import Image
+
+try:
+    import torch_musa
+    from torch_musa.core.device import set_device
+except ModuleNotFoundError:
+    torch_musa = None
+
 import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
+from wan.utils.platform import get_torch_distributed_backend


 EXAMPLE_PROMPT = {
@@ -275,9 +283,9 @@ def generate(args):
         logging.info(
             f"offload_model is not specified, set to {args.offload_model}.")
     if world_size > 1:
-        torch.cuda.set_device(local_rank)
+        set_device(local_rank)
        dist.init_process_group(
-            backend="nccl",
+            backend=get_torch_distributed_backend(),
             init_method="env://",
             rank=rank,
             world_size=world_size)
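The guarded import above is the idiom this change set applies in nearly every file: bind the CUDA helper first, then rebind the same name to its torch_musa counterpart when that package is importable, and keep a torch_musa sentinel for the places that must branch explicitly. A minimal, self-contained sketch of the behavior, assuming a machine with at most one of the two backends present:

    from torch.cuda import set_device  # default CUDA binding

    try:
        import torch_musa  # importing also registers the MUSA backend with torch
        from torch_musa.core.device import set_device  # rebinds the same name
    except ModuleNotFoundError:
        torch_musa = None  # sentinel for explicit branches elsewhere

    set_device(0)  # call sites stay backend-agnostic

Because the rebinding happens once at import time, call sites pay no per-call dispatch cost.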
wan/distributed/fsdp.py

@@ -3,11 +3,17 @@ import gc
 from functools import partial

 import torch
+from torch.cuda import empty_cache
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 from torch.distributed.utils import _free_storage

+try:
+    import torch_musa
+    from torch_musa.core.memory import empty_cache
+except ModuleNotFoundError:
+    torch_musa = None
+

 def shard_model(
     model,
@@ -40,4 +46,4 @@ def free_model(model):
             _free_storage(m._handle.flat_param.data)
     del model
     gc.collect()
-    torch.cuda.empty_cache()
+    empty_cache()
wan/distributed/xdit_context_parallel.py

@@ -6,7 +6,15 @@ from xfuser.core.distributed import (
     get_sequence_parallel_world_size,
     get_sp_group,
 )
-from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
+
+attn_type: AttnType = AttnType.FA
+
+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    attn_type = AttnType.TORCH
+except ImportError:
+    torch_musa = None

 from ..modules.model import sinusoidal_embedding_1d

@@ -24,6 +32,19 @@ def pad_freqs(original_tensor, target_len):
     return padded_tensor


+def pad_tensor(original_tensor, target_len, pad_value=0.0):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.full(
+        (pad_size, s1, s2),
+        pad_value,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device,
+    )
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
 @amp.autocast(enabled=False)
 def rope_apply(x, grid_sizes, freqs):
     """
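Unlike the existing pad_freqs, pad_tensor lets the caller choose the fill value. rope_apply_musa below pads the cosine table with 1.0 and the sine table with 0.0, so padded positions receive the identity rotation 1 + 0i and pass through unchanged. A small self-contained check of that property (shapes are illustrative only):

    import torch

    def pad_tensor(original_tensor, target_len, pad_value=0.0):
        seq_len, s1, s2 = original_tensor.shape
        pad = torch.full((target_len - seq_len, s1, s2), pad_value,
                         dtype=original_tensor.dtype, device=original_tensor.device)
        return torch.cat([original_tensor, pad], dim=0)

    cos = pad_tensor(torch.rand(5, 1, 8), 8, pad_value=1.0)  # cos(0) = 1
    sin = pad_tensor(torch.rand(5, 1, 8), 8, pad_value=0.0)  # sin(0) = 0
    x_real, x_imag = torch.rand(8, 1, 8), torch.rand(8, 1, 8)
    out_real = x_real * cos - x_imag * sin
    assert torch.equal(out_real[5:], x_real[5:])  # padded rows are untouched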
@@ -65,6 +86,69 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()


+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    c0 = c - 2 * (c // 3)
+    c1 = c // 3
+    c2 = c // 3
+
+    # split freqs
+    freqs_real = freqs[0].split([c0, c1, c2], dim=1)
+    freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = x[i, :seq_len].reshape(s, n, -1, 2)
+        x_real = x_i[..., 0]
+        x_imag = x_i[..., 1]
+        freqs_real = torch.cat(
+            [
+                freqs_real[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs_real[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs_real[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+        freqs_imag = torch.cat(
+            [
+                freqs_imag[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs_imag[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs_imag[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+
+        freqs_real = pad_tensor(freqs_real, s * sp_size, 1.0)
+        freqs_imag = pad_tensor(freqs_imag, s * sp_size, 0.0)
+
+        freqs_real_rank = freqs_real[(sp_rank * s):((sp_rank + 1) * s), :, :]
+        freqs_imag_rank = freqs_imag[(sp_rank * s):((sp_rank + 1) * s), :, :]
+
+        out_real = x_real * freqs_real_rank - x_imag * freqs_imag_rank
+        out_imag = x_real * freqs_imag_rank + x_imag * freqs_real_rank
+
+        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
+        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
+
+        # append to collection
+        output.append(x_out)
+    return torch.stack(output)
+
+
 def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
     # embeddings
     c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
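The padding plus per-rank slicing exists because, under sequence parallelism, each rank holds only s tokens of the full sequence; the frequency tables are built for the whole sample, padded up to s * sp_size, and each rank then takes its own window. A toy illustration with hypothetical sizes:

    import torch

    s, sp_size = 4, 2        # tokens per rank, sequence-parallel world size
    seq_len = 6              # real tokens; the last two slots are padding
    cos = torch.arange(seq_len, dtype=torch.float32).view(seq_len, 1, 1)

    padded = torch.cat([cos, torch.ones(s * sp_size - seq_len, 1, 1)], dim=0)
    for sp_rank in range(sp_size):
        window = padded[sp_rank * s:(sp_rank + 1) * s]  # rows this rank applies
        print(sp_rank, window.flatten().tolist())
    # rank 0 applies rows 0..3; rank 1 applies rows 4..5 plus identity padding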
@@ -109,9 +193,17 @@ def usp_dit_forward(
     if self.model_type == 'i2v':
         assert clip_fea is not None and y is not None
     # params
+    dtype = self.patch_embedding.weight.dtype
     device = self.patch_embedding.weight.device
-    if self.freqs.device != device:
-        self.freqs = self.freqs.to(device)
+    if torch_musa is not None:
+        if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
+            self.freqs = (
+                self.freqs[0].to(dtype=dtype, device=device),
+                self.freqs[-1].to(dtype=dtype, device=device)
+            )
+    else:
+        if self.freqs.dtype != dtype or self.freqs.device != device:
+            self.freqs = self.freqs.to(dtype=dtype, device=device)

     if self.model_type != 'vace' and y is not None:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -200,8 +292,13 @@ def usp_attn_forward(self,
         return q, k, v

     q, k, v = qkv_fn(x)
-    q = rope_apply(q, grid_sizes, freqs)
-    k = rope_apply(k, grid_sizes, freqs)
+
+    if torch_musa is not None:
+        q = rope_apply_musa(q, grid_sizes, freqs)
+        k = rope_apply_musa(k, grid_sizes, freqs)
+    else:
+        q = rope_apply(q, grid_sizes, freqs)
+        k = rope_apply(k, grid_sizes, freqs)

     # TODO: We should use unpaded q,k,v for attention.
     # k_lens = seq_lens // get_sequence_parallel_world_size()
@@ -210,7 +307,7 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

-    x = xFuserLongContextAttention()(
+    x = xFuserLongContextAttention(attn_type=attn_type)(
         None,
         query=half(q),
         key=half(k),
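attn_type is resolved once at import time and then handed to every xFuserLongContextAttention call above: AttnType.FA by default, AttnType.TORCH when torch_musa is importable, the assumption being that flash-attention kernels are unavailable on MUSA and torch's scaled_dot_product_attention must stand in. A minimal sketch of the same select-once pattern, with the enum stubbed out so the snippet runs anywhere (the real AttnType comes from xfuser):

    import enum

    class AttnType(enum.Enum):       # stand-in for xfuser's AttnType
        FA = "flash-attention"
        TORCH = "torch-sdpa"

    attn_type = AttnType.FA
    try:
        import torch_musa            # assumed: no FA kernels on MUSA
        attn_type = AttnType.TORCH   # fall back to SDPA-based attention
    except ModuleNotFoundError:
        pass

    print(f"long-context attention backend: {attn_type.value}")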
wan/first_last_frame2video.py

@@ -12,10 +12,19 @@ from functools import partial
 import numpy as np
 import torch
 import torch.cuda.amp as amp
+from torch.cuda import empty_cache, synchronize
 import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from torch_musa.core.memory import empty_cache
+    from torch_musa.core.device import synchronize
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
@@ -27,6 +36,7 @@ from .utils.fm_solvers import (
     retrieve_timesteps,
 )
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device


 class WanFLF2V:
@@ -66,7 +76,7 @@ class WanFLF2V:
         init_on_cpu (`bool`, *optional*, defaults to True):
             Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = get_device(device_id)
         self.config = config
         self.rank = rank
         self.use_usp = use_usp
@@ -323,7 +333,7 @@ class WanFLF2V:
         }

         if offload_model:
-            torch.cuda.empty_cache()
+            empty_cache()

         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
@@ -336,12 +346,12 @@ class WanFLF2V:
                 latent_model_input, t=timestep, **arg_c)[0].to(
                     torch.device('cpu') if offload_model else self.device)
             if offload_model:
-                torch.cuda.empty_cache()
+                empty_cache()
             noise_pred_uncond = self.model(
                 latent_model_input, t=timestep, **arg_null)[0].to(
                     torch.device('cpu') if offload_model else self.device)
             if offload_model:
-                torch.cuda.empty_cache()
+                empty_cache()
             noise_pred = noise_pred_uncond + guide_scale * (
                 noise_pred_cond - noise_pred_uncond)

@@ -361,7 +371,7 @@ class WanFLF2V:

         if offload_model:
             self.model.cpu()
-            torch.cuda.empty_cache()
+            empty_cache()

         if self.rank == 0:
             videos = self.vae.decode(x0)
@@ -370,7 +380,7 @@ class WanFLF2V:
         del sample_scheduler
         if offload_model:
             gc.collect()
-            torch.cuda.synchronize()
+            synchronize()
         if dist.is_initialized():
             dist.barrier()
wan/image2video.py

@@ -12,10 +12,19 @@ from functools import partial
 import numpy as np
 import torch
 import torch.cuda.amp as amp
+from torch.cuda import empty_cache, synchronize
 import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from torch_musa.core.memory import empty_cache
+    from torch_musa.core.device import synchronize
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .distributed.fsdp import shard_model
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
@@ -27,6 +36,7 @@ from .utils.fm_solvers import (
     retrieve_timesteps,
 )
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device


 class WanI2V:
@@ -66,7 +76,7 @@ class WanI2V:
         init_on_cpu (`bool`, *optional*, defaults to True):
             Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = get_device(device_id)
         self.config = config
         self.rank = rank
         self.use_usp = use_usp
@@ -296,7 +306,7 @@ class WanI2V:
         }

         if offload_model:
-            torch.cuda.empty_cache()
+            empty_cache()

         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
@@ -309,12 +319,12 @@ class WanI2V:
                 latent_model_input, t=timestep, **arg_c)[0].to(
                     torch.device('cpu') if offload_model else self.device)
             if offload_model:
-                torch.cuda.empty_cache()
+                empty_cache()
             noise_pred_uncond = self.model(
                 latent_model_input, t=timestep, **arg_null)[0].to(
                     torch.device('cpu') if offload_model else self.device)
             if offload_model:
-                torch.cuda.empty_cache()
+                empty_cache()
             noise_pred = noise_pred_uncond + guide_scale * (
                 noise_pred_cond - noise_pred_uncond)

@@ -334,7 +344,7 @@ class WanI2V:

         if offload_model:
             self.model.cpu()
-            torch.cuda.empty_cache()
+            empty_cache()

         if self.rank == 0:
             videos = self.vae.decode(x0)
@@ -343,7 +353,7 @@ class WanI2V:
         del sample_scheduler
         if offload_model:
             gc.collect()
-            torch.cuda.synchronize()
+            synchronize()
         if dist.is_initialized():
             dist.barrier()
wan/modules/attention.py

@@ -1,4 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import warnings
+
 import torch

 try:
@@ -13,7 +15,13 @@ try:
 except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False

-import warnings
+try:
+    import torch_musa
+    FLASH_ATTN_3_AVAILABLE = False
+    FLASH_ATTN_2_AVAILABLE = False
+except ModuleNotFoundError:
+    torch_musa = None
+

 __all__ = [
     'flash_attention',
@@ -51,7 +59,7 @@ def flash_attention(
     """
     half_dtypes = (torch.float16, torch.bfloat16)
     assert dtype in half_dtypes
-    assert q.device.type == 'cuda' and q.size(-1) <= 256
+    assert (q.device.type == "cuda" or q.device.type == "musa") and q.size(-1) <= 256

     # params
     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@@ -173,7 +181,7 @@ def attention(
         v = v.transpose(1, 2).to(dtype)

         out = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)

         out = out.transpose(1, 2).contiguous()
         return out
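Two behavioral changes sit in the last two hunks: flash_attention now accepts MUSA tensors in its device assertion, and the SDPA-based attention fallback forwards the caller's softmax_scale instead of dropping it (previously SDPA always used its default of 1/sqrt(head_dim)). The two scalings only coincide when softmax_scale is None or equals that default, as this check with illustrative shapes shows:

    import math

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    default = F.scaled_dot_product_attention(q, k, v)                  # scale=None
    explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(64))
    torch.testing.assert_close(default, explicit)                      # same result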
wan/modules/clip.py

@@ -6,12 +6,20 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.cuda.amp as amp
 import torchvision.transforms as T

 from .attention import flash_attention
 from .tokenizers import HuggingfaceTokenizer
 from .xlm_roberta import XLMRoberta

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from .attention import attention as flash_attention
+except ModuleNotFoundError:
+    torch_musa = None
+
 __all__ = [
     'XLMRobertaCLIP',
     'clip_xlm_roberta_vit_h_14',
@@ -82,7 +90,10 @@ class SelfAttention(nn.Module):

         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
+        if torch_musa is not None:
+            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
+        else:
+            x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
         x = x.reshape(b, s, c)

         # output
@@ -194,7 +205,10 @@ class AttentionPool(nn.Module):
         k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

         # compute attention
-        x = flash_attention(q, k, v, version=2)
+        if torch_musa is not None:
+            x = flash_attention(q, k, v)
+        else:
+            x = flash_attention(q, k, v, version=2)
         x = x.reshape(b, 1, c)

         # output
@@ -537,6 +551,6 @@ class CLIPModel:
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
-        with torch.cuda.amp.autocast(dtype=self.dtype):
+        with amp.autocast(dtype=self.dtype):
             out = self.model.visual(videos, use_31_block=True)
         return out
wan/modules/model.py

@@ -7,7 +7,14 @@ import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin

-from .attention import flash_attention
+from wan.modules.attention import flash_attention
+
+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from wan.modules.attention import attention as flash_attention
+except ModuleNotFoundError:
+    torch_musa = None

 __all__ = ['WanModel']

@@ -19,7 +26,7 @@ def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
     half = dim // 2
-    position = position.type(torch.float64)
+    position = position.type(torch.float32)

     # calculation
     sinusoid = torch.outer(
@@ -39,6 +46,36 @@ def rope_params(max_seq_len, dim, theta=10000):
     return freqs


+@amp.autocast(enabled=False)
+def rope_params_real(
+    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
+):
+    assert dim % 2 == 0
+    freqs_real = torch.outer(
+        torch.arange(max_seq_len, dtype=dtype, device=device),
+        1.0
+        / torch.pow(
+            theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
+        ),
+    )
+    return torch.cos(freqs_real)
+
+
+@amp.autocast(enabled=False)
+def rope_params_imag(
+    max_seq_len, dim, theta=10000, dtype=torch.float32, device=torch.device("cpu")
+):
+    assert dim % 2 == 0
+    freqs_imag = torch.outer(
+        torch.arange(max_seq_len, dtype=dtype, device=device),
+        1.0
+        / torch.pow(
+            theta, torch.arange(0, dim, 2, dtype=dtype, device=device).div(dim)
+        ),
+    )
+    return torch.sin(freqs_imag)
+
+
 @amp.autocast(enabled=False)
 def rope_apply(x, grid_sizes, freqs):
     n, c = x.size(2), x.size(3) // 2
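rope_params_real and rope_params_imag build the same angle table as rope_params and return its cosine and sine, which are exactly the real and imaginary parts of the complex table torch.polar produces on the CUDA path; keeping the two halves as separate real tensors avoids complex dtypes on MUSA. A quick equivalence check with an illustrative size:

    import torch

    max_seq_len, dim = 32, 16
    angles = torch.outer(
        torch.arange(max_seq_len, dtype=torch.float32),
        1.0 / torch.pow(10000, torch.arange(0, dim, 2, dtype=torch.float32).div(dim)))

    complex_table = torch.polar(torch.ones_like(angles), angles)        # CUDA-path table
    torch.testing.assert_close(complex_table.real, torch.cos(angles))   # rope_params_real
    torch.testing.assert_close(complex_table.imag, torch.sin(angles))   # rope_params_imag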
@@ -70,6 +107,55 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()


+@amp.autocast(enabled=False)
+def rope_apply_musa(x, grid_sizes, freqs):
+    n, c = x.size(2), x.size(3) // 2
+    c0 = c - 2 * (c // 3)
+    c1 = c // 3
+    c2 = c // 3
+
+    # split freqs
+    freqs_real = freqs[0].split([c0, c1, c2], dim=1)
+    freqs_imag = freqs[-1].split([c0, c1, c2], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
+        x_real = x_i[..., 0]
+        x_imag = x_i[..., 1]
+        freqs_real = torch.cat(
+            [
+                freqs_real[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+                freqs_real[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+                freqs_real[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, c)
+        freqs_imag = torch.cat(
+            [
+                freqs_imag[0][:f].view(f, 1, 1, c0).expand(f, h, w, c0),
+                freqs_imag[1][:h].view(1, h, 1, c1).expand(f, h, w, c1),
+                freqs_imag[2][:w].view(1, 1, w, c2).expand(f, h, w, c2),
+            ],
+            dim=-1,
+        ).reshape(seq_len, 1, c)
+
+        out_real = x_real * freqs_real - x_imag * freqs_imag
+        out_imag = x_real * freqs_imag + x_imag * freqs_real
+
+        # apply rotary embedding
+        x_out = torch.stack([out_real, out_imag], dim=-1).flatten(2)
+        x_out = torch.cat([x_out, x[i, seq_len:]], dim=0)
+
+        # append to collection
+        output.append(x_out)
+    return torch.stack(output)
+
+
 class WanRMSNorm(nn.Module):

     def __init__(self, dim, eps=1e-5):
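The arithmetic inside the loop is complex multiplication written out over (real, imag) pairs: (a + bi)(cos t + i sin t) = (a cos t - b sin t) + i(a sin t + b cos t). A self-contained check that this real-arithmetic form matches the complex rotation rope_apply performs via torch.view_as_complex (shapes are made up for illustration):

    import torch

    theta = torch.rand(16, 4)                      # rotation angle per position/pair
    cos, sin = torch.cos(theta), torch.sin(theta)
    x = torch.randn(16, 4, 2)                      # last dim holds (real, imag)

    ref = torch.view_as_real(
        torch.view_as_complex(x) * torch.polar(torch.ones_like(theta), theta))

    out = torch.stack([x[..., 0] * cos - x[..., 1] * sin,
                       x[..., 0] * sin + x[..., 1] * cos], dim=-1)
    torch.testing.assert_close(out, ref)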
@@ -146,12 +232,22 @@ class WanSelfAttention(nn.Module):

         q, k, v = qkv_fn(x)

-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if torch_musa is not None:
+            x = flash_attention(
+                q=rope_apply_musa(q, grid_sizes, freqs),
+                k=rope_apply_musa(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+            )
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+            )

         # output
         x = x.flatten(2)
@@ -477,12 +573,33 @@ class WanModel(ModelMixin, ConfigMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.freqs = torch.cat([
-            rope_params(1024, d - 4 * (d // 6)),
-            rope_params(1024, 2 * (d // 6)),
-            rope_params(1024, 2 * (d // 6))
-        ],
-                               dim=1)
+        if torch_musa is not None:
+            freqs_real = torch.cat(
+                [
+                    rope_params_real(1024, d - 4 * (d // 6)),
+                    rope_params_real(1024, 2 * (d // 6)),
+                    rope_params_real(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )
+            freqs_imag = torch.cat(
+                [
+                    rope_params_imag(1024, d - 4 * (d // 6)),
+                    rope_params_imag(1024, 2 * (d // 6)),
+                    rope_params_imag(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )
+            self.freqs = (freqs_real, freqs_imag)
+        else:
+            self.freqs = torch.cat(
+                [
+                    rope_params(1024, d - 4 * (d // 6)),
+                    rope_params(1024, 2 * (d // 6)),
+                    rope_params(1024, 2 * (d // 6)),
+                ],
+                dim=1,
+            )

         if model_type == 'i2v' or model_type == 'flf2v':
             self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
@@ -523,9 +640,17 @@ class WanModel(ModelMixin, ConfigMixin):
         if self.model_type == 'i2v' or self.model_type == 'flf2v':
             assert clip_fea is not None and y is not None
         # params
+        dtype = self.patch_embedding.weight.dtype
         device = self.patch_embedding.weight.device
-        if self.freqs.device != device:
-            self.freqs = self.freqs.to(device)
+        if torch_musa is not None:
+            if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
+                self.freqs = (
+                    self.freqs[0].to(dtype=dtype, device=device),
+                    self.freqs[-1].to(dtype=dtype, device=device)
+                )
+        else:
+            if self.freqs.dtype != dtype or self.freqs.device != device:
+                self.freqs = self.freqs.to(dtype=dtype, device=device)

         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
wan/modules/t5.py

@@ -6,6 +6,13 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.cuda import current_device
+
+try:
+    import torch_musa
+    from torch_musa.core.device import current_device
+except ModuleNotFoundError:
+    torch_musa = None

 from .tokenizers import HuggingfaceTokenizer

@@ -475,7 +482,7 @@ class T5EncoderModel:
         self,
         text_len,
         dtype=torch.bfloat16,
-        device=torch.cuda.current_device(),
+        device=current_device(),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
wan/modules/vace_model.py

@@ -4,6 +4,12 @@ import torch.cuda.amp as amp
 import torch.nn as nn
 from diffusers.configuration_utils import register_to_config

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d

wan/modules/vae.py

@@ -5,8 +5,17 @@ import torch
 import torch.cuda.amp as amp
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import Upsample
 from einops import rearrange

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+except ModuleNotFoundError:
+    torch_musa = None
+
+from wan.utils.platform import get_device
+
 __all__ = [
     'WanVAE',
 ]
@@ -622,7 +631,7 @@ class WanVAE:
                  z_dim=16,
                  vae_pth='cache/vae_step_411000.pth',
                  dtype=torch.float,
-                 device="cuda"):
+                 device=get_device()):
         self.dtype = dtype
         self.device = device
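One subtlety of the new default: device=get_device() in WanVAE's signature is evaluated once, when the class body runs at import time, not at each instantiation, so the default device is fixed by whatever backend is visible when wan is imported. The underlying Python rule, in a runnable toy:

    import itertools

    counter = itertools.count()

    def make(value=next(counter)):  # default computed once, at def time
        return value

    print(make(), make())  # 0 0: both calls see the same captured default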
wan/text2video.py

@@ -11,9 +11,18 @@ from functools import partial

 import torch
 import torch.cuda.amp as amp
+from torch.cuda import empty_cache, synchronize
 import torch.distributed as dist
 from tqdm import tqdm

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from torch_musa.core.memory import empty_cache
+    from torch_musa.core.device import synchronize
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
@@ -24,7 +33,7 @@ from .utils.fm_solvers import (
     retrieve_timesteps,
 )
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.platform import get_device

 class WanT2V:
@@ -60,7 +69,7 @@ class WanT2V:
         t5_cpu (`bool`, *optional*, defaults to False):
             Whether to place T5 model on CPU. Only works without t5_fsdp.
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = get_device(device_id)
         self.config = config
         self.rank = rank
         self.t5_cpu = t5_cpu
@@ -256,7 +265,7 @@ class WanT2V:
         x0 = latents
         if offload_model:
             self.model.cpu()
-            torch.cuda.empty_cache()
+            empty_cache()
         if self.rank == 0:
             videos = self.vae.decode(x0)

@@ -264,7 +273,7 @@ class WanT2V:
         del sample_scheduler
         if offload_model:
             gc.collect()
-            torch.cuda.synchronize()
+            synchronize()
         if dist.is_initialized():
             dist.barrier()
wan/utils/__init__.py

@@ -5,9 +5,10 @@ from .fm_solvers import (
 )
 from .fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .vace_processor import VaceVideoProcessor
+from .platform import get_device, get_torch_distributed_backend

 __all__ = [
     'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
     'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
-    'VaceVideoProcessor'
+    'VaceVideoProcessor', 'get_device', 'get_torch_distributed_backend'
 ]
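With the re-export in place, downstream code can pull the platform helpers from wan.utils directly; a minimal smoke test (get_torch_distributed_backend is omitted because, per platform.py below, it raises on hosts with no accelerator):

    from wan.utils import get_device

    # per platform.py: a device index (int) on CUDA/MUSA hosts when no rank is
    # passed, torch.device("cpu") otherwise
    print(get_device())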
wan/utils/platform.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+
+try:
+    import torch_musa
+except ModuleNotFoundError:
+    torch_musa = None
+
+
+def _is_musa():
+    try:
+        if torch.musa.is_available():
+            return True
+    except ModuleNotFoundError:
+        return False
+
+
+def get_device(local_rank: Optional[int] = None) -> torch.device:
+    if torch.cuda.is_available():
+        return torch.cuda.current_device() if local_rank is None else torch.device("cuda", local_rank)
+    elif _is_musa():
+        return torch.musa.current_device() if local_rank is None else torch.device("musa", local_rank)
+    else:
+        return torch.device("cpu")
+
+
+def get_torch_distributed_backend() -> str:
+    if torch.cuda.is_available():
+        return "nccl"
+    elif _is_musa():
+        return "mccl"
+    else:
+        raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
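A sketch of how the two helpers are meant to combine at a distributed call site, mirroring generate.py above; the rank values would normally come from the launcher environment, and the MASTER_* settings here are single-process demo values:

    import os

    import torch.distributed as dist
    from wan.utils.platform import get_device, get_torch_distributed_backend

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    local_rank = 0                        # normally int(os.environ["LOCAL_RANK"])
    device = get_device(local_rank)       # torch.device("cuda", 0) or ("musa", 0)
    dist.init_process_group(
        backend=get_torch_distributed_backend(),  # "nccl" on CUDA, "mccl" on MUSA
        init_method="env://",
        rank=0,
        world_size=1)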
wan/vace.py (24 changed lines)

@@ -13,6 +13,7 @@ from functools import partial

 import torch
 import torch.cuda.amp as amp
+from torch.cuda import empty_cache, synchronize
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn.functional as F
@@ -20,6 +21,14 @@ import torchvision.transforms.functional as TF
 from PIL import Image
 from tqdm import tqdm

+try:
+    import torch_musa
+    import torch_musa.core.amp as amp
+    from torch_musa.core.memory import empty_cache
+    from torch_musa.core.device import synchronize
+except ModuleNotFoundError:
+    torch_musa = None
+
 from .modules.vace_model import VaceWanModel
 from .text2video import (
     FlowDPMSolverMultistepScheduler,
@@ -32,6 +41,7 @@ from .text2video import (
     shard_model,
 )
 from .utils.vace_processor import VaceVideoProcessor
+from .utils.platform import get_device, get_torch_distributed_backend


 class WanVace(WanT2V):
@@ -68,7 +78,7 @@ class WanVace(WanT2V):
         t5_cpu (`bool`, *optional*, defaults to False):
             Whether to place T5 model on CPU. Only works without t5_fsdp.
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = get_device(device_id)
         self.config = config
         self.rank = rank
         self.t5_cpu = t5_cpu
@@ -460,7 +470,7 @@ class WanVace(WanT2V):
         x0 = latents
         if offload_model:
             self.model.cpu()
-            torch.cuda.empty_cache()
+            empty_cache()
         if self.rank == 0:
             videos = self.decode_latent(x0, input_ref_images)

@@ -468,7 +478,7 @@ class WanVace(WanT2V):
         del sample_scheduler
         if offload_model:
             gc.collect()
-            torch.cuda.synchronize()
+            synchronize()
         if dist.is_initialized():
             dist.barrier()

@@ -568,7 +578,7 @@ class WanVaceMP(WanVace):

         torch.cuda.set_device(gpu)
         dist.init_process_group(
-            backend='nccl',
+            backend=get_torch_distributed_backend(),
             init_method='env://',
             rank=rank,
             world_size=world_size)
@@ -633,7 +643,7 @@ class WanVaceMP(WanVace):
             model = shard_fn(model)
             sample_neg_prompt = self.config.sample_neg_prompt

-            torch.cuda.empty_cache()
+            empty_cache()
             event = initialized_events[gpu]
             in_q = in_q_list[gpu]
             event.set()
@@ -748,7 +758,7 @@ class WanVaceMP(WanVace):
                 generator=seed_g)[0]
             latents = [temp_x0.squeeze(0)]

-            torch.cuda.empty_cache()
+            empty_cache()
             x0 = latents
             if rank == 0:
                 videos = self.decode_latent(
@@ -758,7 +768,7 @@ class WanVaceMP(WanVace):
             del sample_scheduler
             if offload_model:
                 gc.collect()
-                torch.cuda.synchronize()
+                synchronize()
             if dist.is_initialized():
                 dist.barrier()