Merge pull request #8 from songruizz/release_0.0.3

logging the generate time
This commit is contained in:
zc8gerard 2025-06-10 22:07:48 +08:00 committed by GitHub
commit caac4152cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@ import argparse
import logging import logging
import os import os
import sys import sys
import time
import warnings import warnings
from tqdm import tqdm from tqdm import tqdm
from datetime import datetime from datetime import datetime
@ -978,6 +979,7 @@ def generate(args):
logging.info( logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...") f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start = time.time()
video = wan_t2v.generate( video = wan_t2v.generate(
args.prompt, args.prompt,
size=SIZE_CONFIGS[args.size], size=SIZE_CONFIGS[args.size],
@ -988,6 +990,7 @@ def generate(args):
guide_scale=args.sample_guide_scale, guide_scale=args.sample_guide_scale,
seed=args.base_seed, seed=args.base_seed,
offload_model=args.offload_model) offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "i2v" in args.task: elif "i2v" in args.task:
if args.prompt is None: if args.prompt is None:
@ -1067,6 +1070,7 @@ def generate(args):
wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2 wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
logging.info("Generating video ...") logging.info("Generating video ...")
start = time.time()
video = wan_i2v.generate( video = wan_i2v.generate(
args.prompt, args.prompt,
img, img,
@ -1078,6 +1082,7 @@ def generate(args):
guide_scale=args.sample_guide_scale, guide_scale=args.sample_guide_scale,
seed=args.base_seed, seed=args.base_seed,
offload_model=args.offload_model) offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "flf2v" in args.task: elif "flf2v" in args.task:
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1125,6 +1130,7 @@ def generate(args):
) )
logging.info("Generating video ...") logging.info("Generating video ...")
start = time.time()
video = wan_flf2v.generate( video = wan_flf2v.generate(
args.prompt, args.prompt,
first_frame, first_frame,
@ -1137,6 +1143,7 @@ def generate(args):
guide_scale=args.sample_guide_scale, guide_scale=args.sample_guide_scale,
seed=args.base_seed, seed=args.base_seed,
offload_model=args.offload_model) offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
elif "vace" in args.task: elif "vace" in args.task:
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@ -1179,6 +1186,7 @@ def generate(args):
], args.frame_num, SIZE_CONFIGS[args.size], device) ], args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...") logging.info(f"Generating video...")
start = time.time()
video = wan_vace.generate( video = wan_vace.generate(
args.prompt, args.prompt,
src_video, src_video,
@ -1192,6 +1200,7 @@ def generate(args):
guide_scale=args.sample_guide_scale, guide_scale=args.sample_guide_scale,
seed=args.base_seed, seed=args.base_seed,
offload_model=args.offload_model) offload_model=args.offload_model)
logging.info(f"Generate done, elapsed {round(time.time() - start, 2)} s")
else: else:
raise ValueError(f"Unkown task type: {args.task}") raise ValueError(f"Unkown task type: {args.task}")