mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 07:44:53 +00:00
Added --batch_size parameter to generate multiple videos per run without reloading the model.
Fixed seed synchronization handling by pre-generating them before the batch.
This commit is contained in:
parent
d7b01803de
commit
3d8f5f1a04
42
generate.py
42
generate.py
@ -187,6 +187,13 @@ def _parse_args():
|
||||
default=5.0,
|
||||
help="Classifier free guidance scale.")
|
||||
|
||||
# 20250226 pftq: batch size to do more videos without reloading the model
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of videos per run without reloading the model again.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
_validate_args(args)
|
||||
@ -270,11 +277,34 @@ def generate(args):
|
||||
dist.broadcast_object_list(base_seed, src=0)
|
||||
args.base_seed = base_seed[0]
|
||||
|
||||
# 20250226 pftq: Do more videos in one batch without reloading the model
|
||||
wan_t2v = None
|
||||
wan_i2v = None
|
||||
|
||||
# 20250226 pftq: Pre-generate and sync seeds for all videos
|
||||
seeds = []
|
||||
if dist.is_initialized():
|
||||
if rank == 0:
|
||||
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
|
||||
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
|
||||
else:
|
||||
seeds = [None] * args.batch_size
|
||||
dist.broadcast_object_list(seeds, src=0)
|
||||
else:
|
||||
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
|
||||
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
|
||||
|
||||
for batch_index in range(args.batch_size):
|
||||
if batch_index > 0:
|
||||
args.base_seed = seeds[batch_index]
|
||||
if args.batch_size > 1:
|
||||
logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
|
||||
|
||||
if "t2v" in args.task or "t2i" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
if args.use_prompt_extend:
|
||||
if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
@ -296,6 +326,7 @@ def generate(args):
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||
logging.info("Creating WanT2V pipeline.")
|
||||
wan_t2v = wan.WanT2V(
|
||||
config=cfg,
|
||||
@ -329,8 +360,9 @@ def generate(args):
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
logging.info(f"Input image: {args.image}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: load once
|
||||
img = Image.open(args.image).convert("RGB")
|
||||
if args.use_prompt_extend:
|
||||
if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
@ -353,6 +385,7 @@ def generate(args):
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||
logging.info("Creating WanI2V pipeline.")
|
||||
wan_i2v = wan.WanI2V(
|
||||
config=cfg,
|
||||
@ -386,6 +419,11 @@ def generate(args):
|
||||
suffix = '.png' if "t2i" in args.task else '.mp4'
|
||||
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
||||
|
||||
# 20250226 pftq: if multiple videos per batch, change the filename index
|
||||
if args.batch_size > 1:
|
||||
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
|
||||
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4")
|
||||
|
||||
if "t2i" in args.task:
|
||||
logging.info(f"Saving generated image to {args.save_file}")
|
||||
cache_image(
|
||||
|
Loading…
Reference in New Issue
Block a user