Latency metric storer

This commit is contained in:
Ayaan Momin 2025-04-25 09:28:17 -07:00
parent 204f899b64
commit f3aaa59335

View File

@ -5,6 +5,8 @@ import logging
import os
import sys
import warnings
import time
import json
warnings.filterwarnings('ignore')
@ -206,6 +208,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--perf_save_file",
type=str,
default=None,
help="The file to save performance metrics to.")
args = parser.parse_args()
@ -233,6 +240,9 @@ def generate(args):
device = local_rank
_init_logging(rank)
# Start timing total execution
total_start_time = time.time()
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
@ -330,6 +340,7 @@ def generate(args):
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start_time = time.time()
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -340,6 +351,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
generation_time = time.time() - start_time
elif "i2v" in args.task:
if args.prompt is None:
@ -386,6 +398,7 @@ def generate(args):
)
logging.info("Generating video ...")
start_time = time.time()
video = wan_i2v.generate(
args.prompt,
img,
@ -397,6 +410,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
generation_time = time.time() - start_time
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -444,6 +458,7 @@ def generate(args):
)
logging.info("Generating video ...")
start_time = time.time()
video = wan_flf2v.generate(
args.prompt,
first_frame,
@ -457,6 +472,7 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model
)
generation_time = time.time() - start_time
if rank == 0:
if args.save_file is None:
@ -483,6 +499,29 @@ def generate(args):
nrow=1,
normalize=True,
value_range=(-1, 1))
# Save performance metrics if requested
if args.perf_save_file is not None:
# Calculate total time including model loading
total_time = time.time() - total_start_time
perf_data = {
"task": args.task,
"size": args.size,
"frame_num": args.frame_num,
"sample_steps": args.sample_steps,
"ulysses_size": args.ulysses_size,
"ring_size": args.ring_size,
"generation_time_seconds": generation_time,
"total_time_seconds": total_time
}
with open(args.perf_save_file, 'w') as f:
json.dump(perf_data, f, indent=2)
logging.info(f"Saved performance metrics to {args.perf_save_file}")
logging.info(f"Generation time: {generation_time:.2f} seconds")
logging.info(f"Total time (including model loading): {total_time:.2f} seconds")
logging.info("Finished.")