Compare commits

..

2 Commits

Author SHA1 Message Date
Houchen Li
e0facce6e9
Merge 6d7fc288d8 into 7c81b2f27d 2025-08-11 14:23:37 +08:00
Houchen Li
6d7fc288d8 [feature] adapt for Moore Threads GPU family 2025-08-11 14:21:31 +08:00
5 changed files with 19 additions and 29 deletions

View File

@ -282,13 +282,15 @@ def _init_logging(rank):
def _init_profiler(): def _init_profiler():
return torch.profiler.profile( profiler = torch.profiler.profile(
activities=get_torch_profiler_activities(), activities=get_torch_profiler_activities(),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'), on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
record_shapes=True, record_shapes=True,
profile_memory=True, profile_memory=True,
with_stack=True, with_stack=True,
) )
profiler.start()
return profiler
def _finalize_profiler(profiler): def _finalize_profiler(profiler):
@ -297,8 +299,10 @@ def _finalize_profiler(profiler):
sort_by=f"{get_device_type()}_time_total", sort_by=f"{get_device_type()}_time_total",
row_limit=20, row_limit=20,
) )
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f: file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
with open(f"logs/profiling-{file_name}.txt", "w") as f:
f.write(table) f.write(table)
del file_name
def generate(args): def generate(args):
@ -670,23 +674,21 @@ def generate(args):
if "t2i" in args.task: if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}") logging.info(f"Saving generated image to {args.save_file}")
with ChronoInspector("Saving generated image"): cache_image(
cache_image( tensor=video.squeeze(1)[None],
tensor=video.squeeze(1)[None], save_file=args.save_file,
save_file=args.save_file, nrow=1,
nrow=1, normalize=True,
normalize=True, value_range=(-1, 1))
value_range=(-1, 1))
else: else:
logging.info(f"Saving generated video to {args.save_file}") logging.info(f"Saving generated video to {args.save_file}")
with ChronoInspector("Saving generated video"): cache_video(
cache_video( tensor=video[None],
tensor=video[None], save_file=args.save_file,
save_file=args.save_file, fps=cfg.sample_fps,
fps=cfg.sample_fps, nrow=1,
nrow=1, normalize=True,
normalize=True, value_range=(-1, 1))
value_range=(-1, 1))
logging.info("Finished.") logging.info("Finished.")

View File

@ -200,9 +200,6 @@ class WanFLF2V:
- H: Frame height (from max_area) - H: Frame height (from max_area)
- W: Frame width (from max_area) - W: Frame width (from max_area)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -195,9 +195,6 @@ class WanI2V:
- H: Frame height (from max_area) - H: Frame height (from max_area)
- W: Frame width (from max_area) - W: Frame width (from max_area)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -172,9 +172,6 @@ class WanT2V:
- H: Frame height (from size) - H: Frame height (from size)
- W: Frame width (from size) - W: Frame width (from size)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -358,9 +358,6 @@ class WanVace(WanT2V):
- H: Frame height (from size) - H: Frame height (from size)
- W: Frame width (from size) - W: Frame width (from size)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0: