Compare commits


2 Commits

Author      SHA1        Message                                       Date
Houchen Li  a76ccc9d36  Merge 21df001d53 into 7c81b2f27d              2025-08-08 17:55:51 +08:00
Houchen Li  21df001d53  [feature] adapt for Moore Threads GPU family  2025-08-08 17:08:19 +08:00
8 changed files with 268 additions and 110 deletions

.gitignore vendored (4 changes)

@@ -21,7 +21,7 @@
 *.html
 *.pdf
 *.whl
-cache
+*cache/
 __pycache__/
 storage/
 samples/
@@ -29,9 +29,11 @@ samples/
 !requirements.txt
 .DS_Store
 *DS_Store
+.vscode
 google/
 Wan2.1-T2V-14B/
 Wan2.1-T2V-1.3B/
 Wan2.1-I2V-14B-480P/
 Wan2.1-I2V-14B-720P/
 poetry.lock
+logs/

generate.py

@@ -25,8 +25,12 @@ import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
-from wan.utils.platform import get_torch_distributed_backend
-from wan.utils.chrono_inspector import ChronoInspector
+from wan.utils.platform import (
+    get_device_type,
+    get_torch_distributed_backend,
+    get_torch_profiler_activities,
+)

 EXAMPLE_PROMPT = {
@@ -252,6 +256,11 @@ def _parse_args():
         type=float,
         default=5.0,
         help="Classifier free guidance scale.")
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="profile the generating procedure.")

     args = parser.parse_args()
@@ -272,6 +281,26 @@ def _init_logging(rank):
         logging.basicConfig(level=logging.ERROR)


+def _init_profiler():
+    return torch.profiler.profile(
+        activities=get_torch_profiler_activities(),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    )
+
+
+def _finalize_profiler(profiler):
+    profiler.stop()
+    table = profiler.key_averages().table(
+        sort_by=f"{get_device_type()}_time_total",
+        row_limit=20,
+    )
+    with open(f"logs/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}-profiling.txt", "w") as f:
+        f.write(table)
+
+
 def generate(args):
     rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -339,6 +368,10 @@ def generate(args):
         dist.broadcast_object_list(base_seed, src=0)
         args.base_seed = base_seed[0]

+    profiler = None
+    if args.profile and rank == 0:
+        profiler = _init_profiler()
+
     if "t2v" in args.task or "t2i" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -366,7 +399,6 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating WanT2V pipeline.")
-    with ChronoInspector("Creating WanT2V pipeline"):
     wan_t2v = wan.WanT2V(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
@@ -376,10 +408,11 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        profiler=profiler,
     )

     logging.info("Warming up WanT2V pipeline ...")
-    with ChronoInspector("Warming up WanT2V pipeline"):
+    with torch.no_grad():
         _ = wan_t2v.generate(
             args.prompt,
             size=SIZE_CONFIGS[args.size],
@@ -392,7 +425,6 @@ def generate(args):
             offload_model=args.offload_model)

     logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
-    with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
     video = wan_t2v.generate(
         args.prompt,
         size=SIZE_CONFIGS[args.size],
@@ -437,7 +469,6 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating WanI2V pipeline.")
-    with ChronoInspector("Creating WanI2V pipeline"):
     wan_i2v = wan.WanI2V(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
@@ -447,10 +478,11 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        profiler=profiler,
     )

     logging.info("Warming up WanI2V pipeline ...")
-    with ChronoInspector("Warming up WanI2V pipeline"):
+    with torch.no_grad():
         _ = wan_i2v.generate(
             args.prompt,
             img,
@@ -464,7 +496,6 @@ def generate(args):
             offload_model=args.offload_model)

     logging.info("Generating video ...")
-    with ChronoInspector("Generating video"):
     video = wan_i2v.generate(
         args.prompt,
         img,
@@ -511,7 +542,6 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating WanFLF2V pipeline.")
-    with ChronoInspector("Creating WanFLF2V pipeline"):
     wan_flf2v = wan.WanFLF2V(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
@@ -521,10 +551,11 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        profiler=profiler
     )

     logging.info("Warming up WanFLF2V pipeline ...")
-    with ChronoInspector("Warming up WanFLF2V pipeline"):
+    with torch.no_grad():
         _ = wan_flf2v.generate(
             args.prompt,
             first_frame,
@@ -539,7 +570,6 @@ def generate(args):
             offload_model=args.offload_model)

     logging.info("Generating video ...")
-    with ChronoInspector("Generating video"):
     video = wan_flf2v.generate(
         args.prompt,
         first_frame,
@@ -576,7 +606,6 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating VACE pipeline.")
-    with ChronoInspector("Creating VACE pipeline"):
     wan_vace = wan.WanVace(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
@@ -586,6 +615,7 @@ def generate(args):
         dit_fsdp=args.dit_fsdp,
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
+        profiler=profiler
     )

     src_video, src_mask, src_ref_images = wan_vace.prepare_source(
@@ -595,8 +625,8 @@ def generate(args):
         ], args.frame_num, SIZE_CONFIGS[args.size], device)

     logging.info("Warming up VACE pipeline ...")
-    with ChronoInspector("Warming up VACE pipeline"):
-        video = wan_vace.generate(
+    with torch.no_grad():
+        _ = wan_vace.generate(
             args.prompt,
             src_video,
             src_mask,
@@ -611,7 +641,6 @@ def generate(args):
             offload_model=args.offload_model)

     logging.info(f"Generating video...")
-    with ChronoInspector("Generating video"):
     video = wan_vace.generate(
         args.prompt,
         src_video,
@@ -628,6 +657,9 @@ def generate(args):
     else:
         raise ValueError(f"Unkown task type: {args.task}")

+    if args.profile and rank == 0:
+        _finalize_profiler(profiler)
+
     if rank == 0:
         if args.save_file is None:
             formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
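Taken together, the generate.py changes give the profiler a start/step/stop lifecycle: rank 0 builds it via _init_profiler, each pipeline advances it once per sampling timestep, and _finalize_profiler writes a sorted operator table. Below is a minimal, self-contained sketch of that lifecycle; the toy step function and the five-iteration loop are hypothetical stand-ins, while the torch.profiler calls match the ones the patch uses.

import torch

def toy_step():
    # Hypothetical stand-in for one denoising step of a Wan pipeline.
    a = torch.randn(256, 256)
    return a @ a

profiler = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)
profiler.start()          # pipeline.generate() calls this once on rank 0
for _ in range(5):
    toy_step()
    profiler.step()       # one profiler step per sampling timestep
profiler.stop()           # _finalize_profiler stops, then dumps the table
print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=20))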

wan/first_last_frame2video.py

@@ -6,6 +6,7 @@ import os
 import random
 import sys
 import types
+from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -54,6 +55,7 @@ class WanFLF2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -143,6 +145,7 @@ class WanFLF2V:
             self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
+        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -197,6 +200,14 @@ class WanFLF2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
+        if self.profiler and self.rank == 0:
+            self.profiler.start()
+
+        start_time = 0.0
+        end_time = 0.0
+        if self.rank == 0:
+            start_time = perf_counter()
+
         first_frame_size = first_frame.size
         last_frame_size = last_frame.size
         first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
@@ -289,6 +300,10 @@ class WanFLF2V:
             ])[0]
             y = torch.concat([msk, y])

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -339,8 +354,14 @@ class WanFLF2V:
             if offload_model:
                 empty_cache()

+        if self.rank == 0:
+            start_time = perf_counter()
+
         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
+            if self.profiler and self.rank == 0:
+                self.profiler.step()
+
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
@@ -370,6 +391,10 @@ class WanFLF2V:
                 generator=seed_g)[0]
             latent = temp_x0.squeeze(0)

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
+
         x0 = [latent.to(self.device)]
         del latent_model_input, timestep
@@ -378,7 +403,10 @@ class WanFLF2V:
             empty_cache()

         if self.rank == 0:
+            start_time = perf_counter()
             videos = self.vae.decode(x0)
+            end_time = perf_counter()
+            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latent
         del sample_scheduler
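The same rank-0 timing pattern (initialize start_time/end_time, wrap a phase in perf_counter, emit one logging.info per phase) recurs in every pipeline this commit touches, and it is presumably what the removed ChronoInspector helper encapsulated. A hypothetical context manager expressing the same measurement, shown only for illustration and not part of the commit:

import logging
from contextlib import contextmanager
from time import perf_counter

@contextmanager
def rank0_timer(label, rank):
    # Mirrors the inline blocks above: time a phase, log only on rank 0.
    start = perf_counter()
    try:
        yield
    finally:
        if rank == 0:
            logging.info(f"[{label}] Elapsed time: {perf_counter() - start:.2f} seconds")

with rank0_timer("preprocess and VAE encode", rank=0):
    pass  # preprocessing and VAE encoding would run here

One caveat with either form: GPU kernels launch asynchronously, so a CPU-side perf_counter can under-report a phase unless it ends in a synchronizing call (for example, a device-to-host copy).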

wan/image2video.py

@@ -6,6 +6,7 @@ import os
 import random
 import sys
 import types
+from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -54,6 +55,7 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
+        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -143,6 +145,7 @@ class WanI2V:
             self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
+        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -192,6 +195,14 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
+        if self.profiler and self.rank == 0:
+            self.profiler.start()
+
+        start_time = 0.0
+        end_time = 0.0
+        if self.rank == 0:
+            start_time = perf_counter()
+
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

         F = frame_num
@@ -262,6 +273,10 @@ class WanI2V:
             ])[0]
             y = torch.concat([msk, y])

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -312,8 +327,14 @@ class WanI2V:
             if offload_model:
                 empty_cache()

+        if self.rank == 0:
+            start_time = perf_counter()
+
         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
+            if self.profiler and self.rank == 0:
+                self.profiler.step()
+
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
@@ -343,6 +364,10 @@ class WanI2V:
                 generator=seed_g)[0]
             latent = temp_x0.squeeze(0)

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
+
         x0 = [latent.to(self.device)]
         del latent_model_input, timestep
@@ -351,7 +376,10 @@ class WanI2V:
             empty_cache()

         if self.rank == 0:
+            start_time = perf_counter()
             videos = self.vae.decode(x0)
+            end_time = perf_counter()
+            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latent
         del sample_scheduler

wan/modules/model.py

@@ -16,7 +16,7 @@ try:
     from wan.modules.attention import attention as flash_attention
     torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
-    pass
+    torch_musa = None

 __all__ = ['WanModel']
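This one-line hunk fixes a real failure mode of the guarded-import idiom: if the except branch only executes pass, a name imported inside the try (here, presumably torch_musa) is never bound when the MUSA build is missing, and any later "if torch_musa" check raises NameError. A minimal sketch of the corrected pattern, using the module name as it appears elsewhere in this commit:

try:
    import torch_musa  # Moore Threads' PyTorch backend, optional
except ModuleNotFoundError:
    torch_musa = None  # bind the name so feature checks below cannot NameError

if torch_musa is not None:
    print("MUSA backend available")  # e.g. enable muDNN TF32 here
else:
    print("falling back to stock PyTorch")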

wan/text2video.py

@@ -6,6 +6,7 @@ import os
 import random
 import sys
 import types
+from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -50,6 +51,7 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
+        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -124,6 +126,7 @@ class WanT2V:
             self.model.to(self.device)

         self.sample_neg_prompt = config.sample_neg_prompt
+        self.profiler = profiler

     def generate(self,
                  input_prompt,
@@ -169,6 +172,14 @@ class WanT2V:
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
+        if self.profiler and self.rank == 0:
+            self.profiler.start()
+
+        start_time = 0.0
+        end_time = 0.0
+        if self.rank == 0:
+            start_time = perf_counter()
+
         # preprocess
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -208,6 +219,10 @@ class WanT2V:
                 generator=seed_g)
         ]

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -244,8 +259,14 @@ class WanT2V:
         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}

+        if self.rank == 0:
+            start_time = perf_counter()
+
         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
+            if self.profiler and self.rank == 0:
+                self.profiler.step()
+
             latent_model_input = latents
             timestep = [t]
@@ -267,12 +288,19 @@ class WanT2V:
                 generator=seed_g)[0]
             latents = [temp_x0.squeeze(0)]

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
+
         x0 = latents
         if offload_model:
             self.model.cpu()
             empty_cache()

         if self.rank == 0:
+            start_time = perf_counter()
             videos = self.vae.decode(x0)
+            end_time = perf_counter()
+            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latents
         del sample_scheduler

wan/utils/platform.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List

 import torch
@@ -8,7 +8,7 @@ except ModuleNotFoundError:
     torch_musa = None


-def _is_musa():
+def _is_musa() -> bool:
     if torch_musa is None:
         return False
     else:
@@ -48,3 +48,14 @@ def get_torch_distributed_backend() -> str:
         return "mccl"
     else:
         raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
+
+
+def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
+    activities: List[torch.profiler.ProfilerActivity] = [
+        torch.profiler.ProfilerActivity.CPU
+    ]
+    if torch.cuda.is_available():
+        activities.append(torch.profiler.ProfilerActivity.CUDA)
+    elif _is_musa():
+        activities.append(torch.profiler.ProfilerActivity.MUSA)
+    return activities
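generate.py sorts the profiler table with sort_by=f"{get_device_type()}_time_total" and imports get_device_type from this module, but its definition falls outside the hunks shown. A plausible sketch, consistent with the helpers above, follows; treat it as an assumption rather than the commit's actual code (in particular, torch.musa.is_available() is guessed by analogy with torch.cuda.is_available()):

import torch

try:
    import torch_musa  # optional Moore Threads backend
except ModuleNotFoundError:
    torch_musa = None

def _is_musa() -> bool:
    if torch_musa is None:
        return False
    # Assumed API, by analogy with torch.cuda.is_available().
    return torch.musa.is_available()

def get_device_type() -> str:
    # Returns "cuda" or "musa", so profiler tables sort by
    # "cuda_time_total" or "musa_time_total" respectively.
    if torch.cuda.is_available():
        return "cuda"
    elif _is_musa():
        return "musa"
    raise NotImplementedError("No Accelerators(NV/MTT GPU) available")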

wan/vace.py

@@ -8,6 +8,7 @@ import sys
 import time
 import traceback
 import types
+from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
@@ -58,6 +59,7 @@ class WanVace(WanT2V):
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
+        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -150,6 +152,8 @@ class WanVace(WanT2V):
             seq_len=75600,
             keep_last=True)

+        self.profiler = profiler
+
     def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
         vae = self.vae if vae is None else vae
         if ref_images is None:
@@ -354,6 +358,14 @@ class WanVace(WanT2V):
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
+        if self.profiler and self.rank == 0:
+            self.profiler.start()
+
+        start_time = 0.0
+        end_time = 0.0
+        if self.rank == 0:
+            start_time = perf_counter()
+
         # preprocess
         # F = frame_num
         # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -404,6 +416,10 @@ class WanVace(WanT2V):
                             (self.patch_size[1] * self.patch_size[2]) *
                             target_shape[1] / self.sp_size) * self.sp_size

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
+
         @contextmanager
         def noop_no_sync():
             yield
@@ -440,13 +456,19 @@ class WanVace(WanT2V):
         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}

+        if self.rank == 0:
+            start_time = perf_counter()
+
+        self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
+            if self.profiler and self.rank == 0:
+                self.profiler.step()
+
             latent_model_input = latents
             timestep = [t]

             timestep = torch.stack(timestep)

-            self.model.to(self.device)
             noise_pred_cond = self.model(
                 latent_model_input,
                 t=timestep,
@@ -471,12 +493,19 @@ class WanVace(WanT2V):
                 generator=seed_g)[0]
             latents = [temp_x0.squeeze(0)]

+        if self.rank == 0:
+            end_time = perf_counter()
+            logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
+
         x0 = latents
         if offload_model:
             self.model.cpu()
             empty_cache()

         if self.rank == 0:
+            start_time = perf_counter()
             videos = self.decode_latent(x0, input_ref_images)
+            end_time = perf_counter()
+            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")

         del noise, latents
         del sample_scheduler
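A side note on the hunk at @@ -440,13 +456,19: besides adding timing and profiler steps, it hoists self.model.to(self.device) out of the denoising loop. Module.to() is effectively a no-op once the weights already sit on the device, but it still traverses the module's parameters on each call, so issuing it once before the loop instead of once per timestep removes that per-step overhead. A runnable micro-sketch of the corrected structure, with a placeholder model and loop:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 8)

model.to(device)                      # hoisted: move weights once
for t in range(4):                    # stands in for the sampling timesteps
    x = torch.randn(1, 8, device=device)
    with torch.no_grad():
        _ = model(x)                  # per-step work; no repeated .to() calls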