mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00
Latency metric storer
This commit is contained in:
parent
204f899b64
commit
f3aaa59335
39
generate.py
39
generate.py
@ -5,6 +5,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
@ -206,6 +208,11 @@ def _parse_args():
|
|||||||
type=float,
|
type=float,
|
||||||
default=5.0,
|
default=5.0,
|
||||||
help="Classifier free guidance scale.")
|
help="Classifier free guidance scale.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--perf_save_file",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The file to save performance metrics to.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -233,6 +240,9 @@ def generate(args):
|
|||||||
device = local_rank
|
device = local_rank
|
||||||
_init_logging(rank)
|
_init_logging(rank)
|
||||||
|
|
||||||
|
# Start timing total execution
|
||||||
|
total_start_time = time.time()
|
||||||
|
|
||||||
if args.offload_model is None:
|
if args.offload_model is None:
|
||||||
args.offload_model = False if world_size > 1 else True
|
args.offload_model = False if world_size > 1 else True
|
||||||
logging.info(
|
logging.info(
|
||||||
@ -330,6 +340,7 @@ def generate(args):
|
|||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||||
|
start_time = time.time()
|
||||||
video = wan_t2v.generate(
|
video = wan_t2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
size=SIZE_CONFIGS[args.size],
|
size=SIZE_CONFIGS[args.size],
|
||||||
@ -340,6 +351,7 @@ def generate(args):
|
|||||||
guide_scale=args.sample_guide_scale,
|
guide_scale=args.sample_guide_scale,
|
||||||
seed=args.base_seed,
|
seed=args.base_seed,
|
||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
elif "i2v" in args.task:
|
elif "i2v" in args.task:
|
||||||
if args.prompt is None:
|
if args.prompt is None:
|
||||||
@ -386,6 +398,7 @@ def generate(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Generating video ...")
|
logging.info("Generating video ...")
|
||||||
|
start_time = time.time()
|
||||||
video = wan_i2v.generate(
|
video = wan_i2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
img,
|
img,
|
||||||
@ -397,6 +410,7 @@ def generate(args):
|
|||||||
guide_scale=args.sample_guide_scale,
|
guide_scale=args.sample_guide_scale,
|
||||||
seed=args.base_seed,
|
seed=args.base_seed,
|
||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
generation_time = time.time() - start_time
|
||||||
else:
|
else:
|
||||||
if args.prompt is None:
|
if args.prompt is None:
|
||||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||||
@ -444,6 +458,7 @@ def generate(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Generating video ...")
|
logging.info("Generating video ...")
|
||||||
|
start_time = time.time()
|
||||||
video = wan_flf2v.generate(
|
video = wan_flf2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
first_frame,
|
first_frame,
|
||||||
@ -457,6 +472,7 @@ def generate(args):
|
|||||||
seed=args.base_seed,
|
seed=args.base_seed,
|
||||||
offload_model=args.offload_model
|
offload_model=args.offload_model
|
||||||
)
|
)
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if args.save_file is None:
|
if args.save_file is None:
|
||||||
@ -483,6 +499,29 @@ def generate(args):
|
|||||||
nrow=1,
|
nrow=1,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
value_range=(-1, 1))
|
value_range=(-1, 1))
|
||||||
|
|
||||||
|
# Save performance metrics if requested
|
||||||
|
if args.perf_save_file is not None:
|
||||||
|
# Calculate total time including model loading
|
||||||
|
total_time = time.time() - total_start_time
|
||||||
|
|
||||||
|
perf_data = {
|
||||||
|
"task": args.task,
|
||||||
|
"size": args.size,
|
||||||
|
"frame_num": args.frame_num,
|
||||||
|
"sample_steps": args.sample_steps,
|
||||||
|
"ulysses_size": args.ulysses_size,
|
||||||
|
"ring_size": args.ring_size,
|
||||||
|
"generation_time_seconds": generation_time,
|
||||||
|
"total_time_seconds": total_time
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(args.perf_save_file, 'w') as f:
|
||||||
|
json.dump(perf_data, f, indent=2)
|
||||||
|
logging.info(f"Saved performance metrics to {args.perf_save_file}")
|
||||||
|
logging.info(f"Generation time: {generation_time:.2f} seconds")
|
||||||
|
logging.info(f"Total time (including model loading): {total_time:.2f} seconds")
|
||||||
|
|
||||||
logging.info("Finished.")
|
logging.info("Finished.")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user