Compare commits

..

2 Commits

Author SHA1 Message Date
Houchen Li
a76ccc9d36
Merge 21df001d53 into 7c81b2f27d 2025-08-08 17:55:51 +08:00
Houchen Li
21df001d53 [feature] adapt for Moore Threads GPU family 2025-08-08 17:08:19 +08:00
5 changed files with 29 additions and 19 deletions

View File

@ -282,15 +282,13 @@ def _init_logging(rank):
def _init_profiler():
profiler = torch.profiler.profile(
return torch.profiler.profile(
activities=get_torch_profiler_activities(),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.start()
return profiler
def _finalize_profiler(profiler):
@ -299,10 +297,8 @@ def _finalize_profiler(profiler):
sort_by=f"{get_device_type()}_time_total",
row_limit=20,
)
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
with open(f"logs/profiling-{file_name}.txt", "w") as f:
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
f.write(table)
del file_name
def generate(args):
@ -674,21 +670,23 @@ def generate(args):
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
with ChronoInspector("Saving generated image"):
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
with ChronoInspector("Saving generated video"):
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")

View File

@ -200,6 +200,9 @@ class WanFLF2V:
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:

View File

@ -195,6 +195,9 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:

View File

@ -172,6 +172,9 @@ class WanT2V:
- H: Frame height (from size)
- W: Frame width (from size)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0:

View File

@ -358,6 +358,9 @@ class WanVace(WanT2V):
- H: Frame height (from size)
- W: Frame width (from size)
"""
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0
end_time = 0.0
if self.rank == 0: