Added --batch_size parameter to generate multiple videos per run without reloading the model.

Fixed seed synchronization handling by pre-generating them before the batch.
This commit is contained in:
pftq 2025-02-26 07:55:30 -08:00 committed by GitHub
parent d7b01803de
commit 3d8f5f1a04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -187,6 +187,13 @@ def _parse_args():
default=5.0, default=5.0,
help="Classifier free guidance scale.") help="Classifier free guidance scale.")
# 20250226 pftq: batch size to do more videos without reloading the model
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Number of videos per run without reloading the model again.")
args = parser.parse_args() args = parser.parse_args()
_validate_args(args) _validate_args(args)
@ -270,11 +277,34 @@ def generate(args):
dist.broadcast_object_list(base_seed, src=0) dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0] args.base_seed = base_seed[0]
# 20250226 pftq: Do more videos in one batch without reloading the model
wan_t2v = None
wan_i2v = None
# 20250226 pftq: Pre-generate and sync seeds for all videos
seeds = []
if dist.is_initialized():
if rank == 0:
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
else:
seeds = [None] * args.batch_size
dist.broadcast_object_list(seeds, src=0)
else:
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
for batch_index in range(args.batch_size):
if batch_index > 0:
args.base_seed = seeds[batch_index]
if args.batch_size > 1:
logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
if "t2v" in args.task or "t2i" in args.task: if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}") logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend: if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
logging.info("Extending prompt ...") logging.info("Extending prompt ...")
if rank == 0: if rank == 0:
prompt_output = prompt_expander( prompt_output = prompt_expander(
@ -296,6 +326,7 @@ def generate(args):
args.prompt = input_prompt[0] args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
if batch_index == 0: # 20250226 pftq: only load the model once
logging.info("Creating WanT2V pipeline.") logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V( wan_t2v = wan.WanT2V(
config=cfg, config=cfg,
@ -329,8 +360,9 @@ def generate(args):
logging.info(f"Input prompt: {args.prompt}") logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}") logging.info(f"Input image: {args.image}")
if batch_index == 0: # 20250226 pftq: load once
img = Image.open(args.image).convert("RGB") img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend: if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
logging.info("Extending prompt ...") logging.info("Extending prompt ...")
if rank == 0: if rank == 0:
prompt_output = prompt_expander( prompt_output = prompt_expander(
@ -353,6 +385,7 @@ def generate(args):
args.prompt = input_prompt[0] args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
if batch_index == 0: # 20250226 pftq: only load the model once
logging.info("Creating WanI2V pipeline.") logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V( wan_i2v = wan.WanI2V(
config=cfg, config=cfg,
@ -386,6 +419,11 @@ def generate(args):
suffix = '.png' if "t2i" in args.task else '.mp4' suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
# 20250226 pftq: if multiple videos per batch, change the filename index
if args.batch_size > 1:
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4")
if "t2i" in args.task: if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}") logging.info(f"Saving generated image to {args.save_file}")
cache_image( cache_image(