mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-25 16:33:31 +00:00
Compare commits
2 Commits
a76ccc9d36
...
e0facce6e9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e0facce6e9 | ||
|
|
6d7fc288d8 |
36
generate.py
36
generate.py
@ -282,13 +282,15 @@ def _init_logging(rank):
|
|||||||
|
|
||||||
|
|
||||||
def _init_profiler():
|
def _init_profiler():
|
||||||
return torch.profiler.profile(
|
profiler = torch.profiler.profile(
|
||||||
activities=get_torch_profiler_activities(),
|
activities=get_torch_profiler_activities(),
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
||||||
record_shapes=True,
|
record_shapes=True,
|
||||||
profile_memory=True,
|
profile_memory=True,
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
)
|
)
|
||||||
|
profiler.start()
|
||||||
|
return profiler
|
||||||
|
|
||||||
|
|
||||||
def _finalize_profiler(profiler):
|
def _finalize_profiler(profiler):
|
||||||
@ -297,8 +299,10 @@ def _finalize_profiler(profiler):
|
|||||||
sort_by=f"{get_device_type()}_time_total",
|
sort_by=f"{get_device_type()}_time_total",
|
||||||
row_limit=20,
|
row_limit=20,
|
||||||
)
|
)
|
||||||
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
|
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
with open(f"logs/profiling-{file_name}.txt", "w") as f:
|
||||||
f.write(table)
|
f.write(table)
|
||||||
|
del file_name
|
||||||
|
|
||||||
|
|
||||||
def generate(args):
|
def generate(args):
|
||||||
@ -670,23 +674,21 @@ def generate(args):
|
|||||||
|
|
||||||
if "t2i" in args.task:
|
if "t2i" in args.task:
|
||||||
logging.info(f"Saving generated image to {args.save_file}")
|
logging.info(f"Saving generated image to {args.save_file}")
|
||||||
with ChronoInspector("Saving generated image"):
|
cache_image(
|
||||||
cache_image(
|
tensor=video.squeeze(1)[None],
|
||||||
tensor=video.squeeze(1)[None],
|
save_file=args.save_file,
|
||||||
save_file=args.save_file,
|
nrow=1,
|
||||||
nrow=1,
|
normalize=True,
|
||||||
normalize=True,
|
value_range=(-1, 1))
|
||||||
value_range=(-1, 1))
|
|
||||||
else:
|
else:
|
||||||
logging.info(f"Saving generated video to {args.save_file}")
|
logging.info(f"Saving generated video to {args.save_file}")
|
||||||
with ChronoInspector("Saving generated video"):
|
cache_video(
|
||||||
cache_video(
|
tensor=video[None],
|
||||||
tensor=video[None],
|
save_file=args.save_file,
|
||||||
save_file=args.save_file,
|
fps=cfg.sample_fps,
|
||||||
fps=cfg.sample_fps,
|
nrow=1,
|
||||||
nrow=1,
|
normalize=True,
|
||||||
normalize=True,
|
value_range=(-1, 1))
|
||||||
value_range=(-1, 1))
|
|
||||||
logging.info("Finished.")
|
logging.info("Finished.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -200,9 +200,6 @@ class WanFLF2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -195,9 +195,6 @@ class WanI2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -172,9 +172,6 @@ class WanT2V:
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@ -358,9 +358,6 @@ class WanVace(WanT2V):
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
end_time = 0.0
|
end_time = 0.0
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user