mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-26 00:43:31 +00:00
Compare commits
2 Commits
e0facce6e9
...
a76ccc9d36
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a76ccc9d36 | ||
|
|
21df001d53 |
36
generate.py
36
generate.py
@ -282,15 +282,13 @@ def _init_logging(rank):
|
|||||||
|
|
||||||
|
|
||||||
def _init_profiler():
|
def _init_profiler():
|
||||||
profiler = torch.profiler.profile(
|
return torch.profiler.profile(
|
||||||
activities=get_torch_profiler_activities(),
|
activities=get_torch_profiler_activities(),
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
||||||
record_shapes=True,
|
record_shapes=True,
|
||||||
profile_memory=True,
|
profile_memory=True,
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
)
|
)
|
||||||
profiler.start()
|
|
||||||
return profiler
|
|
||||||
|
|
||||||
|
|
||||||
def _finalize_profiler(profiler):
|
def _finalize_profiler(profiler):
|
||||||
@ -299,10 +297,8 @@ def _finalize_profiler(profiler):
|
|||||||
sort_by=f"{get_device_type()}_time_total",
|
sort_by=f"{get_device_type()}_time_total",
|
||||||
row_limit=20,
|
row_limit=20,
|
||||||
)
|
)
|
||||||
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
|
||||||
with open(f"logs/profiling-{file_name}.txt", "w") as f:
|
|
||||||
f.write(table)
|
f.write(table)
|
||||||
del file_name
|
|
||||||
|
|
||||||
|
|
||||||
def generate(args):
|
def generate(args):
|
||||||
@ -674,21 +670,23 @@ def generate(args):
|
|||||||
|
|
||||||
if "t2i" in args.task:
|
if "t2i" in args.task:
|
||||||
logging.info(f"Saving generated image to {args.save_file}")
|
logging.info(f"Saving generated image to {args.save_file}")
|
||||||
cache_image(
|
with ChronoInspector("Saving generated image"):
|
||||||
tensor=video.squeeze(1)[None],
|
cache_image(
|
||||||
save_file=args.save_file,
|
tensor=video.squeeze(1)[None],
|
||||||
nrow=1,
|
save_file=args.save_file,
|
||||||
normalize=True,
|
nrow=1,
|
||||||
value_range=(-1, 1))
|
normalize=True,
|
||||||
|
value_range=(-1, 1))
|
||||||
else:
|
else:
|
||||||
logging.info(f"Saving generated video to {args.save_file}")
|
logging.info(f"Saving generated video to {args.save_file}")
|
||||||
cache_video(
|
with ChronoInspector("Saving generated video"):
|
||||||
tensor=video[None],
|
cache_video(
|
||||||
save_file=args.save_file,
|
tensor=video[None],
|
||||||
fps=cfg.sample_fps,
|
save_file=args.save_file,
|
||||||
nrow=1,
|
fps=cfg.sample_fps,
|
||||||
normalize=True,
|
nrow=1,
|
||||||
value_range=(-1, 1))
|
normalize=True,
|
||||||
|
value_range=(-1, 1))
|
||||||
logging.info("Finished.")
|
logging.info("Finished.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -200,6 +200,9 @@ class WanFLF2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
|
if self.profiler and self.rank == 0:
|
||||||
|
self.profiler.start()
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -195,6 +195,9 @@ class WanI2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
|
if self.profiler and self.rank == 0:
|
||||||
|
self.profiler.start()
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -172,6 +172,9 @@ class WanT2V:
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
|
if self.profiler and self.rank == 0:
|
||||||
|
self.profiler.start()
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -358,6 +358,9 @@ class WanVace(WanT2V):
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
|
if self.profiler and self.rank == 0:
|
||||||
|
self.profiler.start()
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user