Compare commits

..

2 Commits

Author SHA1 Message Date
Houchen Li
a76ccc9d36
Merge 21df001d53 into 7c81b2f27d 2025-08-08 17:55:51 +08:00
Houchen Li
21df001d53 [feature] adapt for Moore Threads GPU family 2025-08-08 17:08:19 +08:00
5 changed files with 29 additions and 19 deletions

View File

@ -282,15 +282,13 @@ def _init_logging(rank):
def _init_profiler(): def _init_profiler():
profiler = torch.profiler.profile( return torch.profiler.profile(
activities=get_torch_profiler_activities(), activities=get_torch_profiler_activities(),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'), on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
record_shapes=True, record_shapes=True,
profile_memory=True, profile_memory=True,
with_stack=True, with_stack=True,
) )
profiler.start()
return profiler
def _finalize_profiler(profiler): def _finalize_profiler(profiler):
@ -299,10 +297,8 @@ def _finalize_profiler(profiler):
sort_by=f"{get_device_type()}_time_total", sort_by=f"{get_device_type()}_time_total",
row_limit=20, row_limit=20,
) )
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
with open(f"logs/profiling-{file_name}.txt", "w") as f:
f.write(table) f.write(table)
del file_name
def generate(args): def generate(args):
@ -674,6 +670,7 @@ def generate(args):
if "t2i" in args.task: if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}") logging.info(f"Saving generated image to {args.save_file}")
with ChronoInspector("Saving generated image"):
cache_image( cache_image(
tensor=video.squeeze(1)[None], tensor=video.squeeze(1)[None],
save_file=args.save_file, save_file=args.save_file,
@ -682,6 +679,7 @@ def generate(args):
value_range=(-1, 1)) value_range=(-1, 1))
else: else:
logging.info(f"Saving generated video to {args.save_file}") logging.info(f"Saving generated video to {args.save_file}")
with ChronoInspector("Saving generated video"):
cache_video( cache_video(
tensor=video[None], tensor=video[None],
save_file=args.save_file, save_file=args.save_file,

View File

@ -200,6 +200,9 @@ class WanFLF2V:
- H: Frame height (from max_area) - H: Frame height (from max_area)
- W: Frame width from max_area) - W: Frame width from max_area)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -195,6 +195,9 @@ class WanI2V:
- H: Frame height (from max_area) - H: Frame height (from max_area)
- W: Frame width from max_area) - W: Frame width from max_area)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -172,6 +172,9 @@ class WanT2V:
- H: Frame height (from size) - H: Frame height (from size)
- W: Frame width from size) - W: Frame width from size)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0:

View File

@ -358,6 +358,9 @@ class WanVace(WanT2V):
- H: Frame height (from size) - H: Frame height (from size)
- W: Frame width from size) - W: Frame width from size)
""" """
if self.profiler and self.rank == 0:
self.profiler.start()
start_time = 0.0 start_time = 0.0
end_time = 0.0 end_time = 0.0
if self.rank == 0: if self.rank == 0: