mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-25 16:33:31 +00:00
Compare commits
2 Commits
e0facce6e9
...
a76ccc9d36
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a76ccc9d36 | ||
|
|
21df001d53 |
36
generate.py
36
generate.py
@ -282,15 +282,13 @@ def _init_logging(rank):
|
||||
|
||||
|
||||
def _init_profiler():
|
||||
profiler = torch.profiler.profile(
|
||||
return torch.profiler.profile(
|
||||
activities=get_torch_profiler_activities(),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
profiler.start()
|
||||
return profiler
|
||||
|
||||
|
||||
def _finalize_profiler(profiler):
|
||||
@ -299,10 +297,8 @@ def _finalize_profiler(profiler):
|
||||
sort_by=f"{get_device_type()}_time_total",
|
||||
row_limit=20,
|
||||
)
|
||||
file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
with open(f"logs/profiling-{file_name}.txt", "w") as f:
|
||||
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
|
||||
f.write(table)
|
||||
del file_name
|
||||
|
||||
|
||||
def generate(args):
|
||||
@ -674,21 +670,23 @@ def generate(args):
|
||||
|
||||
if "t2i" in args.task:
|
||||
logging.info(f"Saving generated image to {args.save_file}")
|
||||
cache_image(
|
||||
tensor=video.squeeze(1)[None],
|
||||
save_file=args.save_file,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
with ChronoInspector("Saving generated image"):
|
||||
cache_image(
|
||||
tensor=video.squeeze(1)[None],
|
||||
save_file=args.save_file,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
else:
|
||||
logging.info(f"Saving generated video to {args.save_file}")
|
||||
cache_video(
|
||||
tensor=video[None],
|
||||
save_file=args.save_file,
|
||||
fps=cfg.sample_fps,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
with ChronoInspector("Saving generated video"):
|
||||
cache_video(
|
||||
tensor=video[None],
|
||||
save_file=args.save_file,
|
||||
fps=cfg.sample_fps,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
logging.info("Finished.")
|
||||
|
||||
|
||||
|
||||
@ -200,6 +200,9 @@ class WanFLF2V:
|
||||
- H: Frame height (from max_area)
|
||||
- W: Frame width from max_area)
|
||||
"""
|
||||
if self.profiler and self.rank == 0:
|
||||
self.profiler.start()
|
||||
|
||||
start_time = 0.0
|
||||
end_time = 0.0
|
||||
if self.rank == 0:
|
||||
|
||||
@ -195,6 +195,9 @@ class WanI2V:
|
||||
- H: Frame height (from max_area)
|
||||
- W: Frame width from max_area)
|
||||
"""
|
||||
if self.profiler and self.rank == 0:
|
||||
self.profiler.start()
|
||||
|
||||
start_time = 0.0
|
||||
end_time = 0.0
|
||||
if self.rank == 0:
|
||||
|
||||
@ -172,6 +172,9 @@ class WanT2V:
|
||||
- H: Frame height (from size)
|
||||
- W: Frame width from size)
|
||||
"""
|
||||
if self.profiler and self.rank == 0:
|
||||
self.profiler.start()
|
||||
|
||||
start_time = 0.0
|
||||
end_time = 0.0
|
||||
if self.rank == 0:
|
||||
|
||||
@ -358,6 +358,9 @@ class WanVace(WanT2V):
|
||||
- H: Frame height (from size)
|
||||
- W: Frame width from size)
|
||||
"""
|
||||
if self.profiler and self.rank == 0:
|
||||
self.profiler.start()
|
||||
|
||||
start_time = 0.0
|
||||
end_time = 0.0
|
||||
if self.rank == 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user