mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
Merge f3aaa59335
into 204f899b64
This commit is contained in:
commit
79bdc5628b
39
generate.py
39
generate.py
@ -5,6 +5,8 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import time
|
||||
import json
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@ -206,6 +208,11 @@ def _parse_args():
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="Classifier free guidance scale.")
|
||||
parser.add_argument(
|
||||
"--perf_save_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file to save performance metrics to.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -233,6 +240,9 @@ def generate(args):
|
||||
device = local_rank
|
||||
_init_logging(rank)
|
||||
|
||||
# Start timing total execution
|
||||
total_start_time = time.time()
|
||||
|
||||
if args.offload_model is None:
|
||||
args.offload_model = False if world_size > 1 else True
|
||||
logging.info(
|
||||
@ -330,6 +340,7 @@ def generate(args):
|
||||
|
||||
logging.info(
|
||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||
start_time = time.time()
|
||||
video = wan_t2v.generate(
|
||||
args.prompt,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
@ -340,6 +351,7 @@ def generate(args):
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
elif "i2v" in args.task:
|
||||
if args.prompt is None:
|
||||
@ -386,6 +398,7 @@ def generate(args):
|
||||
)
|
||||
|
||||
logging.info("Generating video ...")
|
||||
start_time = time.time()
|
||||
video = wan_i2v.generate(
|
||||
args.prompt,
|
||||
img,
|
||||
@ -397,6 +410,7 @@ def generate(args):
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
generation_time = time.time() - start_time
|
||||
else:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
@ -444,6 +458,7 @@ def generate(args):
|
||||
)
|
||||
|
||||
logging.info("Generating video ...")
|
||||
start_time = time.time()
|
||||
video = wan_flf2v.generate(
|
||||
args.prompt,
|
||||
first_frame,
|
||||
@ -457,6 +472,7 @@ def generate(args):
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model
|
||||
)
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
@ -483,6 +499,29 @@ def generate(args):
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
|
||||
# Save performance metrics if requested
|
||||
if args.perf_save_file is not None:
|
||||
# Calculate total time including model loading
|
||||
total_time = time.time() - total_start_time
|
||||
|
||||
perf_data = {
|
||||
"task": args.task,
|
||||
"size": args.size,
|
||||
"frame_num": args.frame_num,
|
||||
"sample_steps": args.sample_steps,
|
||||
"ulysses_size": args.ulysses_size,
|
||||
"ring_size": args.ring_size,
|
||||
"generation_time_seconds": generation_time,
|
||||
"total_time_seconds": total_time
|
||||
}
|
||||
|
||||
with open(args.perf_save_file, 'w') as f:
|
||||
json.dump(perf_data, f, indent=2)
|
||||
logging.info(f"Saved performance metrics to {args.perf_save_file}")
|
||||
logging.info(f"Generation time: {generation_time:.2f} seconds")
|
||||
logging.info(f"Total time (including model loading): {total_time:.2f} seconds")
|
||||
|
||||
logging.info("Finished.")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user