Compare commits


2 Commits

Author      SHA1        Message                                        Date
Houchen Li  22f4714dc1  Merge 28919e5663 into 7c81b2f27d               2025-08-06 16:15:07 +03:00
Houchen Li  28919e5663  [feature] adapt for Moore Threads GPU family   2025-08-06 20:05:57 +08:00
8 changed files with 110 additions and 268 deletions

.gitignore (vendored) · 4 changes

@@ -21,7 +21,7 @@
 *.html
 *.pdf
 *.whl
-*cache/
+cache
 __pycache__/
 storage/
 samples/
@@ -29,11 +29,9 @@ samples/
 !requirements.txt
 .DS_Store
 *DS_Store
-.vscode
 google/
 Wan2.1-T2V-14B/
 Wan2.1-T2V-1.3B/
 Wan2.1-I2V-14B-480P/
 Wan2.1-I2V-14B-720P/
 poetry.lock
-logs/

generate.py

@@ -25,12 +25,8 @@ import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
-from wan.utils.platform import (
-    get_device_type,
-    get_torch_distributed_backend,
-    get_torch_profiler_activities,
-)
+from wan.utils.platform import get_torch_distributed_backend
+from wan.utils.chrono_inspector import ChronoInspector

 EXAMPLE_PROMPT = {
@@ -256,11 +252,6 @@ def _parse_args():
         type=float,
         default=5.0,
         help="Classifier free guidance scale.")
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        default=False,
-        help="profile the generating procedure.")

     args = parser.parse_args()
@@ -281,26 +272,6 @@ def _init_logging(rank):
         logging.basicConfig(level=logging.ERROR)


-def _init_profiler():
-    return torch.profiler.profile(
-        activities=get_torch_profiler_activities(),
-        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=True,
-    )
-
-
-def _finalize_profiler(profiler):
-    profiler.stop()
-    table = profiler.key_averages().table(
-        sort_by=f"{get_device_type()}_time_total",
-        row_limit=20,
-    )
-    with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
-        f.write(table)


 def generate(args):
     rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -367,10 +338,6 @@ def generate(args):
     base_seed = [args.base_seed] if rank == 0 else [None]
     dist.broadcast_object_list(base_seed, src=0)
     args.base_seed = base_seed[0]

-    profiler = None
-    if args.profile and rank == 0:
-        profiler = _init_profiler()

     if "t2v" in args.task or "t2i" in args.task:
         if args.prompt is None:
@@ -399,20 +366,20 @@
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating WanT2V pipeline.")
-        wan_t2v = wan.WanT2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler,
-        )
+        with ChronoInspector("Creating WanT2V pipeline"):
+            wan_t2v = wan.WanT2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )

         logging.info("Warming up WanT2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanT2V pipeline"):
             _ = wan_t2v.generate(
                 args.prompt,
                 size=SIZE_CONFIGS[args.size],
@@ -425,16 +392,17 @@
                 offload_model=args.offload_model)

         logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
-        video = wan_t2v.generate(
-            args.prompt,
-            size=SIZE_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
+            video = wan_t2v.generate(
+                args.prompt,
+                size=SIZE_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)

     elif "i2v" in args.task:
         if args.prompt is None:
@@ -469,20 +437,20 @@
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating WanI2V pipeline.")
-        wan_i2v = wan.WanI2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler,
-        )
+        with ChronoInspector("Creating WanI2V pipeline"):
+            wan_i2v = wan.WanI2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )

         logging.info("Warming up WanI2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanI2V pipeline"):
             _ = wan_i2v.generate(
                 args.prompt,
                 img,
@@ -496,17 +464,18 @@
                 offload_model=args.offload_model)

         logging.info("Generating video ...")
-        video = wan_i2v.generate(
-            args.prompt,
-            img,
-            max_area=MAX_AREA_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_i2v.generate(
+                args.prompt,
+                img,
+                max_area=MAX_AREA_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)

     elif "flf2v" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -542,20 +511,20 @@
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating WanFLF2V pipeline.")
-        wan_flf2v = wan.WanFLF2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler
-        )
+        with ChronoInspector("Creating WanFLF2V pipeline"):
+            wan_flf2v = wan.WanFLF2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )

         logging.info("Warming up WanFLF2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanFLF2V pipeline"):
             _ = wan_flf2v.generate(
                 args.prompt,
                 first_frame,
@@ -570,18 +539,19 @@
                 offload_model=args.offload_model)

         logging.info("Generating video ...")
-        video = wan_flf2v.generate(
-            args.prompt,
-            first_frame,
-            last_frame,
-            max_area=MAX_AREA_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_flf2v.generate(
+                args.prompt,
+                first_frame,
+                last_frame,
+                max_area=MAX_AREA_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)

     elif "vace" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -606,17 +576,17 @@
             logging.info(f"Extended prompt: {args.prompt}")

         logging.info("Creating VACE pipeline.")
-        wan_vace = wan.WanVace(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler
-        )
+        with ChronoInspector("Creating VACE pipeline"):
+            wan_vace = wan.WanVace(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )

         src_video, src_mask, src_ref_images = wan_vace.prepare_source(
             [args.src_video], [args.src_mask], [
@@ -625,8 +595,8 @@
             ], args.frame_num, SIZE_CONFIGS[args.size], device)

         logging.info("Warming up VACE pipeline ...")
-        with torch.no_grad():
-            _ = wan_vace.generate(
+        with ChronoInspector("Warming up VACE pipeline"):
+            video = wan_vace.generate(
                 args.prompt,
                 src_video,
                 src_mask,
@@ -641,25 +611,23 @@
                 offload_model=args.offload_model)

         logging.info(f"Generating video...")
-        video = wan_vace.generate(
-            args.prompt,
-            src_video,
-            src_mask,
-            src_ref_images,
-            size=SIZE_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_vace.generate(
+                args.prompt,
+                src_video,
+                src_mask,
+                src_ref_images,
+                size=SIZE_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
     else:
         raise ValueError(f"Unkown task type: {args.task}")

-    if args.profile and rank == 0:
-        _finalize_profiler(profiler)

     if rank == 0:
         if args.save_file is None:
             formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
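Note: the ChronoInspector used throughout generate.py above is imported from wan.utils.chrono_inspector, a file that is not part of this comparison, so its implementation is not visible here. Judging from the perf_counter timing blocks it replaces in the pipeline classes below, a minimal sketch could look like this (the class body is an assumption, not the committed code):

    import logging
    from time import perf_counter

    class ChronoInspector:
        """Context manager that logs the wall-clock duration of a labeled block."""

        def __init__(self, label: str):
            self.label = label
            self._start = 0.0

        def __enter__(self):
            self._start = perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            elapsed = perf_counter() - self._start
            logging.info(f"[{self.label}] Elapsed time: {elapsed:.2f} seconds")
            return False  # never suppress exceptions from the timed block

One behavioral difference to be aware of: the warm-up calls previously ran under torch.no_grad(), and a timer like this does not disable gradients, so any no-grad handling presumably lives inside the generate() implementations themselves.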

wan/first_last_frame2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -55,7 +54,6 @@ class WanFLF2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
-        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -145,7 +143,6 @@ class WanFLF2V:
         self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -200,14 +197,6 @@ class WanFLF2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()

         first_frame_size = first_frame.size
         last_frame_size = last_frame.size
         first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
@@ -300,10 +289,6 @@ class WanFLF2V:
         ])[0]
         y = torch.concat([msk, y])

-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")

         @contextmanager
         def noop_no_sync():
             yield
@@ -354,14 +339,8 @@ class WanFLF2V:
             if offload_model:
                 empty_cache()

-            if self.rank == 0:
-                start_time = perf_counter()

             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()

                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
@@ -391,22 +370,15 @@ class WanFLF2V:
                     generator=seed_g)[0]
                 latent = temp_x0.squeeze(0)

-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
             x0 = [latent.to(self.device)]
             del latent_model_input, timestep

             if offload_model:
                 self.model.cpu()
                 empty_cache()

             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.vae.decode(x0)
-                end_time = perf_counter()
-                logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latent
         del sample_scheduler

wan/image2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -55,7 +54,6 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
-        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -145,7 +143,6 @@ class WanI2V:
         self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -195,14 +192,6 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()

         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

         F = frame_num
@@ -273,10 +262,6 @@ class WanI2V:
         ])[0]
         y = torch.concat([msk, y])

-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")

         @contextmanager
         def noop_no_sync():
             yield
@@ -327,14 +312,8 @@ class WanI2V:
             if offload_model:
                 empty_cache()

-            if self.rank == 0:
-                start_time = perf_counter()

             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()

                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
@@ -364,22 +343,15 @@ class WanI2V:
                     generator=seed_g)[0]
                 latent = temp_x0.squeeze(0)

-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
             x0 = [latent.to(self.device)]
             del latent_model_input, timestep

             if offload_model:
                 self.model.cpu()
                 empty_cache()

             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.vae.decode(x0)
-                end_time = perf_counter()
-                logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latent
         del sample_scheduler

wan/modules/model.py

@@ -16,7 +16,7 @@ try:
     from wan.modules.attention import attention as flash_attention

     torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
-    torch_musa = None
+    pass

 __all__ = ['WanModel']
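For context, the removed torch_musa = None sentinel is dead code after this commit: the module no longer references torch_musa (the MUSA probe lives in wan/utils/platform.py), so an empty except branch suffices. The guarded block presumably reads roughly as follows after the change; the import torch_musa line is inferred from the platform helper below, not shown in this hunk:

    import torch

    try:
        import torch_musa  # inferred: Moore Threads' optional PyTorch backend
        from wan.modules.attention import attention as flash_attention

        # muDNN TF32 toggle, mirroring torch.backends.cudnn.allow_tf32 on CUDA
        torch.backends.mudnn.allow_tf32 = True
    except ModuleNotFoundError:
        pass  # silently fall back on machines without the MUSA stack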

wan/text2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -51,7 +50,6 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
-        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -126,7 +124,6 @@ class WanT2V:
         self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -172,14 +169,6 @@ class WanT2V:
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()

         # preprocess
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -219,10 +208,6 @@ class WanT2V:
                 generator=seed_g)
         ]

-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")

         @contextmanager
         def noop_no_sync():
             yield
@@ -259,14 +244,8 @@ class WanT2V:
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}

-            if self.rank == 0:
-                start_time = perf_counter()

             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()

                 latent_model_input = latents
                 timestep = [t]
@@ -288,19 +267,12 @@ class WanT2V:
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]

-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
             x0 = latents
             if offload_model:
                 self.model.cpu()
                 empty_cache()

             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.vae.decode(x0)
-                end_time = perf_counter()
-                logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latents
         del sample_scheduler

wan/utils/platform.py

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional

 import torch
@@ -8,7 +8,7 @@ except ModuleNotFoundError:
     torch_musa = None


-def _is_musa() -> bool:
+def _is_musa():
     if torch_musa is None:
         return False
     else:
@@ -48,14 +48,3 @@ def get_torch_distributed_backend() -> str:
         return "mccl"
     else:
         raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
-
-
-def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
-    activities: List[torch.profiler.ProfilerActivity] = [
-        torch.profiler.ProfilerActivity.CPU
-    ]
-    if torch.cuda.is_available():
-        activities.append(torch.profiler.ProfilerActivity.CUDA)
-    elif _is_musa():
-        activities.append(torch.profiler.ProfilerActivity.MUSA)
-    return activities
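With the profiler-activities helper gone, this module's remaining job is runtime backend selection. A minimal usage sketch, assuming the standard torch.distributed initialization that generate.py performs (the exact call-site arguments are not shown in this diff):

    import torch.distributed as dist

    from wan.utils.platform import get_torch_distributed_backend

    # Expected to resolve to "nccl" on NVIDIA GPUs (CUDA branch not shown in
    # this hunk) or "mccl" on Moore Threads GPUs, and to raise
    # NotImplementedError when no supported accelerator is present.
    dist.init_process_group(
        backend=get_torch_distributed_backend(),
        init_method="env://",
    )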

wan/vace.py

@@ -8,7 +8,6 @@ import sys
 import time
 import traceback
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -59,7 +58,6 @@ class WanVace(WanT2V):
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
-        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -152,8 +150,6 @@ class WanVace(WanT2V):
             seq_len=75600,
             keep_last=True)

-        self.profiler = profiler

     def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
         vae = self.vae if vae is None else vae
         if ref_images is None:
@@ -358,14 +354,6 @@ class WanVace(WanT2V):
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()

         # preprocess
         # F = frame_num
         # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -416,10 +404,6 @@ class WanVace(WanT2V):
                          (self.patch_size[1] * self.patch_size[2]) *
                          target_shape[1] / self.sp_size) * self.sp_size

-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")

         @contextmanager
         def noop_no_sync():
             yield
@@ -456,19 +440,13 @@ class WanVace(WanT2V):
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}

-            if self.rank == 0:
-                start_time = perf_counter()
-
-            self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()

                 latent_model_input = latents
                 timestep = [t]

                 timestep = torch.stack(timestep)

+                self.model.to(self.device)
                 noise_pred_cond = self.model(
                     latent_model_input,
                     t=timestep,
@@ -493,19 +471,12 @@ class WanVace(WanT2V):
                     generator=seed_g)[0]
                 latents = [temp_x0.squeeze(0)]

-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
             x0 = latents
             if offload_model:
                 self.model.cpu()
                 empty_cache()

             if self.rank == 0:
-                start_time = perf_counter()
                 videos = self.decode_latent(x0, input_ref_images)
-                end_time = perf_counter()
-                logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latents
         del sample_scheduler
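A detail worth noting in the final hunks: self.model.to(self.device) moves from ahead of the denoising loop to inside it. Since Module.to() returns the module without copying when its parameters already sit on the target device, the per-step call is effectively a no-op after the first iteration. A standalone illustration of that idempotence (sketch, not code from this repository):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 8)

    for step in range(3):
        # The first pass moves the weights to `device`; subsequent passes
        # return immediately because the parameters are already there.
        model.to(device)
        out = model(torch.randn(1, 8, device=device))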