Compare commits

..

2 Commits

Author SHA1 Message Date
Houchen Li
a76ccc9d36
Merge 21df001d53 into 7c81b2f27d 2025-08-08 17:55:51 +08:00
Houchen Li
21df001d53 [feature] adapt for Moore Threads GPU family 2025-08-08 17:08:19 +08:00
8 changed files with 268 additions and 110 deletions

4
.gitignore vendored
View File

@ -21,7 +21,7 @@
*.html
*.pdf
*.whl
cache
*cache/
__pycache__/
storage/
samples/
@ -29,9 +29,11 @@ samples/
!requirements.txt
.DS_Store
*DS_Store
.vscode
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock
logs/

View File

@ -25,8 +25,12 @@ import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
from wan.utils.platform import get_torch_distributed_backend
from wan.utils.chrono_inspector import ChronoInspector
from wan.utils.platform import (
get_device_type,
get_torch_distributed_backend,
get_torch_profiler_activities,
)
EXAMPLE_PROMPT = {
@ -252,6 +256,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="profile the generating procedure.")
args = parser.parse_args()
@ -272,6 +281,26 @@ def _init_logging(rank):
logging.basicConfig(level=logging.ERROR)
def _init_profiler():
return torch.profiler.profile(
activities=get_torch_profiler_activities(),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
def _finalize_profiler(profiler):
profiler.stop()
table = profiler.key_averages().table(
sort_by=f"{get_device_type()}_time_total",
row_limit=20,
)
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
f.write(table)
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
@ -339,6 +368,10 @@ def generate(args):
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
profiler = None
if args.profile and rank == 0:
profiler = _init_profiler()
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -366,7 +399,6 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
with ChronoInspector("Creating WanT2V pipeline"):
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -376,10 +408,11 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)
logging.info("Warming up WanT2V pipeline ...")
with ChronoInspector("Warming up WanT2V pipeline"):
with torch.no_grad():
_ = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -392,7 +425,6 @@ def generate(args):
offload_model=args.offload_model)
logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -437,7 +469,6 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
with ChronoInspector("Creating WanI2V pipeline"):
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -447,10 +478,11 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)
logging.info("Warming up WanI2V pipeline ...")
with ChronoInspector("Warming up WanI2V pipeline"):
with torch.no_grad():
_ = wan_i2v.generate(
args.prompt,
img,
@ -464,7 +496,6 @@ def generate(args):
offload_model=args.offload_model)
logging.info("Generating video ...")
with ChronoInspector("Generating video"):
video = wan_i2v.generate(
args.prompt,
img,
@ -511,7 +542,6 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanFLF2V pipeline.")
with ChronoInspector("Creating WanFLF2V pipeline"):
wan_flf2v = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -521,10 +551,11 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)
logging.info("Warming up WanFLF2V pipeline ...")
with ChronoInspector("Warming up WanFLF2V pipeline"):
with torch.no_grad():
_ = wan_flf2v.generate(
args.prompt,
first_frame,
@ -539,7 +570,6 @@ def generate(args):
offload_model=args.offload_model)
logging.info("Generating video ...")
with ChronoInspector("Generating video"):
video = wan_flf2v.generate(
args.prompt,
first_frame,
@ -576,7 +606,6 @@ def generate(args):
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating VACE pipeline.")
with ChronoInspector("Creating VACE pipeline"):
wan_vace = wan.WanVace(
config=cfg,
checkpoint_dir=args.ckpt_dir,
@ -586,6 +615,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
@ -595,8 +625,8 @@ def generate(args):
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info("Warming up VACE pipeline ...")
with ChronoInspector("Warming up VACE pipeline"):
video = wan_vace.generate(
with torch.no_grad():
_ = wan_vace.generate(
args.prompt,
src_video,
src_mask,
@ -611,7 +641,6 @@ def generate(args):
offload_model=args.offload_model)
logging.info(f"Generating video...")
with ChronoInspector("Generating video"):
video = wan_vace.generate(
args.prompt,
src_video,
@ -628,6 +657,9 @@ def generate(args):
else:
raise ValueError(f"Unkown task type: {args.task}")
if args.profile and rank == 0:
_finalize_profiler(profiler)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")

View File

@ -6,6 +6,7 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
@ -54,6 +55,7 @@ class WanFLF2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
profiler=None,
):
r"""
Initializes the image-to-video generation model components.
@ -143,6 +145,7 @@ class WanFLF2V:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -197,6 +200,14 @@ class WanFLF2V:
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
@ -289,6 +300,10 @@ class WanFLF2V:
])[0]
y = torch.concat([msk, y])
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -339,8 +354,14 @@ class WanFLF2V:
if offload_model:
empty_cache()
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = [latent.to(self.device)]
timestep = [t]
@ -370,6 +391,10 @@ class WanFLF2V:
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = [latent.to(self.device)]
del latent_model_input, timestep
@ -378,7 +403,10 @@ class WanFLF2V:
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latent
del sample_scheduler

View File

@ -6,6 +6,7 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
@ -54,6 +55,7 @@ class WanI2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
profiler=None,
):
r"""
Initializes the image-to-video generation model components.
@ -143,6 +145,7 @@ class WanI2V:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -192,6 +195,14 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
@ -262,6 +273,10 @@ class WanI2V:
])[0]
y = torch.concat([msk, y])
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -312,8 +327,14 @@ class WanI2V:
if offload_model:
empty_cache()
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = [latent.to(self.device)]
timestep = [t]
@ -343,6 +364,10 @@ class WanI2V:
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = [latent.to(self.device)]
del latent_model_input, timestep
@ -351,7 +376,10 @@ class WanI2V:
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latent
del sample_scheduler

View File

@ -16,7 +16,7 @@ try:
from wan.modules.attention import attention as flash_attention
torch.backends.mudnn.allow_tf32 = True
except ModuleNotFoundError:
pass
torch_musa = None
__all__ = ['WanModel']

View File

@ -6,6 +6,7 @@ import os
import random
import sys
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
@ -50,6 +51,7 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
profiler=None,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -124,6 +126,7 @@ class WanT2V:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.profiler = profiler
def generate(self,
input_prompt,
@ -169,6 +172,14 @@ class WanT2V:
- H: Frame height (from size)
- W: Frame width from size)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@ -208,6 +219,10 @@ class WanT2V:
generator=seed_g)
]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -244,8 +259,14 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = latents
timestep = [t]
@ -267,12 +288,19 @@ class WanT2V:
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = latents
if offload_model:
self.model.cpu()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.vae.decode(x0)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latents
del sample_scheduler

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
import torch
@ -8,7 +8,7 @@ except ModuleNotFoundError:
torch_musa = None
def _is_musa():
def _is_musa() -> bool:
if torch_musa is None:
return False
else:
@ -48,3 +48,14 @@ def get_torch_distributed_backend() -> str:
return "mccl"
else:
raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
activities: List[torch.profiler.ProfilerActivity] = [
torch.profiler.ProfilerActivity.CPU
]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
elif _is_musa():
activities.append(torch.profiler.ProfilerActivity.MUSA)
return activities

View File

@ -8,6 +8,7 @@ import sys
import time
import traceback
import types
from time import perf_counter
from contextlib import contextmanager
from functools import partial
@ -58,6 +59,7 @@ class WanVace(WanT2V):
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
profiler=None,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -150,6 +152,8 @@ class WanVace(WanT2V):
seq_len=75600,
keep_last=True)
self.profiler = profiler
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
@ -354,6 +358,14 @@ class WanVace(WanT2V):
- H: Frame height (from size)
- W: Frame width from size)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:
start_time = perf_counter()
# preprocess
# F = frame_num
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@ -404,6 +416,10 @@ class WanVace(WanT2V):
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
@contextmanager
def noop_no_sync():
yield
@ -440,13 +456,19 @@ class WanVace(WanT2V):
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if self.rank == 0:
start_time = perf_counter()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
if self.profiler and self.rank == 0:
self.profiler.step()
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input,
t=timestep,
@ -471,12 +493,19 @@ class WanVace(WanT2V):
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
if self.rank == 0:
end_time = perf_counter()
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
x0 = latents
if offload_model:
self.model.cpu()
empty_cache()
if self.rank == 0:
start_time = perf_counter()
videos = self.decode_latent(x0, input_ref_images)
end_time = perf_counter()
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
del noise, latents
del sample_scheduler