Merge pull request #8 from songruizz/release_0.0.3

logging the generate time
This commit is contained in:
zc8gerard 2025-06-10 22:07:48 +08:00 committed by GitHub
commit caac4152cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@ import argparse
import logging
import os
import sys
import time
import warnings
from tqdm import tqdm
from datetime import datetime
@ -978,6 +979,7 @@ def generate(args):
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start = time.time()
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
@ -988,6 +990,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "i2v" in args.task:
if args.prompt is None:
@ -1067,6 +1070,7 @@ def generate(args):
wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
logging.info("Generating video ...")
start = time.time()
video = wan_i2v.generate(
args.prompt,
img,
@ -1078,6 +1082,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1125,6 +1130,7 @@ def generate(args):
)
logging.info("Generating video ...")
start = time.time()
video = wan_flf2v.generate(
args.prompt,
first_frame,
@ -1137,6 +1143,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1179,6 +1186,7 @@ def generate(args):
], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
start = time.time()
video = wan_vace.generate(
args.prompt,
src_video,
@ -1192,6 +1200,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
else:
raise ValueError(f"Unkown task type: {args.task}")