Added --batch_size parameter to generate multiple videos per run without reloading the model.

This saves a few minutes per video by avoiding reloading the model and re-entering the command each time.
This commit is contained in:
pftq 2025-02-26 06:57:27 -08:00 committed by GitHub
parent 190e928681
commit b15339baa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -187,6 +187,13 @@ def _parse_args():
default=5.0, default=5.0,
help="Classifier free guidance scale.") help="Classifier free guidance scale.")
# 20250226 pftq: batch size to do more videos without reloading the model
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Number of videos per run without reloading the model again.")
args = parser.parse_args() args = parser.parse_args()
_validate_args(args) _validate_args(args)
@ -270,142 +277,156 @@ def generate(args):
dist.broadcast_object_list(base_seed, src=0) dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0] args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task: # 20250226 pftq: Do more videos in one batch without reloading the model
if args.prompt is None: for batch_index in range(args.batch_size):
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] if batch_index > 0:
logging.info(f"Input prompt: {args.prompt}") args.base_seed = random.randint(0, sys.maxsize)
if args.use_prompt_extend: if args.batch_size > 1:
logging.info("Extending prompt ...") logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
if rank == 0:
prompt_output = prompt_expander( if "t2v" in args.task or "t2i" in args.task:
args.prompt, if args.prompt is None:
tar_lang=args.prompt_extend_target_lang, args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
seed=args.base_seed) logging.info(f"Input prompt: {args.prompt}")
if prompt_output.status == False: if args.use_prompt_extend:
logging.info( logging.info("Extending prompt ...")
f"Extending prompt failed: {prompt_output.message}") if rank == 0:
logging.info("Falling back to original prompt.") prompt_output = prompt_expander(
input_prompt = args.prompt args.prompt,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else: else:
input_prompt = prompt_output.prompt input_prompt = [None]
input_prompt = [input_prompt] if dist.is_initialized():
else: dist.broadcast_object_list(input_prompt, src=0)
input_prompt = [None] args.prompt = input_prompt[0]
if dist.is_initialized(): logging.info(f"Extended prompt: {args.prompt}")
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0] if batch_index == 0: # 20250226 pftq: only load the model once
logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
logging.info("Creating WanT2V pipeline.") config=cfg,
wan_t2v = wan.WanT2V( checkpoint_dir=args.ckpt_dir,
config=cfg, device_id=device,
checkpoint_dir=args.ckpt_dir, rank=rank,
device_id=device, t5_fsdp=args.t5_fsdp,
rank=rank, dit_fsdp=args.dit_fsdp,
t5_fsdp=args.t5_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
dit_fsdp=args.dit_fsdp, t5_cpu=args.t5_cpu,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1), )
t5_cpu=args.t5_cpu,
) logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
logging.info( video = wan_t2v.generate(
f"Generating {'image' if 't2i' in args.task else 'video'} ...") args.prompt,
video = wan_t2v.generate( size=SIZE_CONFIGS[args.size],
args.prompt, frame_num=args.frame_num,
size=SIZE_CONFIGS[args.size], shift=args.sample_shift,
frame_num=args.frame_num, sample_solver=args.sample_solver,
shift=args.sample_shift, sampling_steps=args.sample_steps,
sample_solver=args.sample_solver, guide_scale=args.sample_guide_scale,
sampling_steps=args.sample_steps, seed=args.base_seed,
guide_scale=args.sample_guide_scale, offload_model=args.offload_model)
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else: else:
logging.info(f"Saving generated video to {args.save_file}") if args.prompt is None:
cache_video( args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
tensor=video[None], if args.image is None:
save_file=args.save_file, args.image = EXAMPLE_PROMPT[args.task]["image"]
fps=cfg.sample_fps, logging.info(f"Input prompt: {args.prompt}")
nrow=1, logging.info(f"Input image: {args.image}")
normalize=True,
value_range=(-1, 1)) img = Image.open(args.image).convert("RGB")
logging.info("Finished.") if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
if batch_index == 0: # 20250226 pftq: only load the model once
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
# 20250226 pftq: if multiple videos per batch, change the filename index
if args.batch_size > 1:
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", str(batch_index+1)+".mp4")
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
# Script entry point: parse CLI arguments, then run the generation pipeline.
# (Reconstructed — the captured diff view had duplicated the old/new columns
# onto each line, garbling this guard.)
if __name__ == "__main__":
    args = _parse_args()
    generate(args)