logging the generate time

This commit is contained in:
songruizz 2025-06-10 20:59:37 +08:00
parent feba2a62c8
commit 0492d77199

View File

@ -3,6 +3,7 @@ import argparse
import logging
import os
import sys
import time
import warnings
from tqdm import tqdm
from datetime import datetime
@ -977,6 +978,7 @@ def generate(args):
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start = time.time()
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -987,6 +989,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "i2v" in args.task:
if args.prompt is None:
@ -1066,6 +1069,7 @@ def generate(args):
wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
logging.info("Generating video ...")
start = time.time()
video = wan_i2v.generate(
args.prompt,
img,
@ -1077,6 +1081,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1124,6 +1129,7 @@ def generate(args):
)
logging.info("Generating video ...")
start = time.time()
video = wan_flf2v.generate(
args.prompt,
first_frame,
@ -1136,6 +1142,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1178,6 +1185,7 @@ def generate(args):
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
start = time.time()
video = wan_vace.generate(
args.prompt,
src_video,
@ -1191,6 +1199,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
else:
raise ValueError(f"Unkown task type: {args.task}")