From f3aaa593352511b7f90879207095cfb651eafcc4 Mon Sep 17 00:00:00 2001 From: Ayaan Momin Date: Fri, 25 Apr 2025 09:28:17 -0700 Subject: [PATCH] Latency metric storer --- generate.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/generate.py b/generate.py index 73f273e..45918eb 100644 --- a/generate.py +++ b/generate.py @@ -5,6 +5,8 @@ import logging import os import sys import warnings +import time +import json warnings.filterwarnings('ignore') @@ -206,6 +208,11 @@ def _parse_args(): type=float, default=5.0, help="Classifier free guidance scale.") + parser.add_argument( + "--perf_save_file", + type=str, + default=None, + help="The file to save performance metrics to.") args = parser.parse_args() @@ -233,6 +240,9 @@ def generate(args): device = local_rank _init_logging(rank) + # Start timing total execution + total_start_time = time.time() + if args.offload_model is None: args.offload_model = False if world_size > 1 else True logging.info( @@ -330,6 +340,7 @@ def generate(args): logging.info( f"Generating {'image' if 't2i' in args.task else 'video'} ...") + start_time = time.time() video = wan_t2v.generate( args.prompt, size=SIZE_CONFIGS[args.size], @@ -340,6 +351,7 @@ def generate(args): guide_scale=args.sample_guide_scale, seed=args.base_seed, offload_model=args.offload_model) + generation_time = time.time() - start_time elif "i2v" in args.task: if args.prompt is None: @@ -386,6 +398,7 @@ def generate(args): ) logging.info("Generating video ...") + start_time = time.time() video = wan_i2v.generate( args.prompt, img, @@ -397,6 +410,7 @@ def generate(args): guide_scale=args.sample_guide_scale, seed=args.base_seed, offload_model=args.offload_model) + generation_time = time.time() - start_time else: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] @@ -444,6 +458,7 @@ def generate(args): ) logging.info("Generating video ...") + start_time = time.time() video = wan_flf2v.generate( args.prompt, first_frame, @@ -457,6 +472,7 @@ def generate(args): seed=args.base_seed, offload_model=args.offload_model ) + generation_time = time.time() - start_time if rank == 0: if args.save_file is None: @@ -483,6 +499,29 @@ def generate(args): nrow=1, normalize=True, value_range=(-1, 1)) + + # Save performance metrics if requested + if args.perf_save_file is not None: + # Calculate total time including model loading + total_time = time.time() - total_start_time + + perf_data = { + "task": args.task, + "size": args.size, + "frame_num": args.frame_num, + "sample_steps": args.sample_steps, + "ulysses_size": args.ulysses_size, + "ring_size": args.ring_size, + "generation_time_seconds": generation_time, + "total_time_seconds": total_time + } + + with open(args.perf_save_file, 'w') as f: + json.dump(perf_data, f, indent=2) + logging.info(f"Saved performance metrics to {args.perf_save_file}") + logging.info(f"Generation time: {generation_time:.2f} seconds") + logging.info(f"Total time (including model loading): {total_time:.2f} seconds") + logging.info("Finished.")