mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 07:44:53 +00:00
Added a --batch_size parameter to generate multiple videos per run without reloading the model.
This saves the few minutes otherwise spent reloading the model and avoids having to re-enter the command for each video.
This commit is contained in:
parent
190e928681
commit
b15339baa3
289
generate.py
289
generate.py
@@ -187,6 +187,13 @@ def _parse_args():
|
|||||||
default=5.0,
|
default=5.0,
|
||||||
help="Classifier free guidance scale.")
|
help="Classifier free guidance scale.")
|
||||||
|
|
||||||
|
# 20250226 pftq: batch size to do more videos without reloading the model
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of videos per run without reloading the model again.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
_validate_args(args)
|
_validate_args(args)
|
||||||
@@ -270,142 +277,156 @@ def generate(args):
|
|||||||
dist.broadcast_object_list(base_seed, src=0)
|
dist.broadcast_object_list(base_seed, src=0)
|
||||||
args.base_seed = base_seed[0]
|
args.base_seed = base_seed[0]
|
||||||
|
|
||||||
if "t2v" in args.task or "t2i" in args.task:
|
# 20250226 pftq: Do more videos in one batch without reloading the model
|
||||||
if args.prompt is None:
|
for batch_index in range(args.batch_size):
|
||||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
if batch_index > 0:
|
||||||
logging.info(f"Input prompt: {args.prompt}")
|
args.base_seed = random.randint(0, sys.maxsize)
|
||||||
if args.use_prompt_extend:
|
if args.batch_size > 1:
|
||||||
logging.info("Extending prompt ...")
|
logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
|
||||||
if rank == 0:
|
|
||||||
prompt_output = prompt_expander(
|
if "t2v" in args.task or "t2i" in args.task:
|
||||||
args.prompt,
|
if args.prompt is None:
|
||||||
tar_lang=args.prompt_extend_target_lang,
|
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||||
seed=args.base_seed)
|
logging.info(f"Input prompt: {args.prompt}")
|
||||||
if prompt_output.status == False:
|
if args.use_prompt_extend:
|
||||||
logging.info(
|
logging.info("Extending prompt ...")
|
||||||
f"Extending prompt failed: {prompt_output.message}")
|
if rank == 0:
|
||||||
logging.info("Falling back to original prompt.")
|
prompt_output = prompt_expander(
|
||||||
input_prompt = args.prompt
|
args.prompt,
|
||||||
|
tar_lang=args.prompt_extend_target_lang,
|
||||||
|
seed=args.base_seed)
|
||||||
|
if prompt_output.status == False:
|
||||||
|
logging.info(
|
||||||
|
f"Extending prompt failed: {prompt_output.message}")
|
||||||
|
logging.info("Falling back to original prompt.")
|
||||||
|
input_prompt = args.prompt
|
||||||
|
else:
|
||||||
|
input_prompt = prompt_output.prompt
|
||||||
|
input_prompt = [input_prompt]
|
||||||
else:
|
else:
|
||||||
input_prompt = prompt_output.prompt
|
input_prompt = [None]
|
||||||
input_prompt = [input_prompt]
|
if dist.is_initialized():
|
||||||
else:
|
dist.broadcast_object_list(input_prompt, src=0)
|
||||||
input_prompt = [None]
|
args.prompt = input_prompt[0]
|
||||||
if dist.is_initialized():
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
dist.broadcast_object_list(input_prompt, src=0)
|
|
||||||
args.prompt = input_prompt[0]
|
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||||
logging.info(f"Extended prompt: {args.prompt}")
|
logging.info("Creating WanT2V pipeline.")
|
||||||
|
wan_t2v = wan.WanT2V(
|
||||||
logging.info("Creating WanT2V pipeline.")
|
config=cfg,
|
||||||
wan_t2v = wan.WanT2V(
|
checkpoint_dir=args.ckpt_dir,
|
||||||
config=cfg,
|
device_id=device,
|
||||||
checkpoint_dir=args.ckpt_dir,
|
rank=rank,
|
||||||
device_id=device,
|
t5_fsdp=args.t5_fsdp,
|
||||||
rank=rank,
|
dit_fsdp=args.dit_fsdp,
|
||||||
t5_fsdp=args.t5_fsdp,
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
dit_fsdp=args.dit_fsdp,
|
t5_cpu=args.t5_cpu,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
)
|
||||||
t5_cpu=args.t5_cpu,
|
|
||||||
)
|
logging.info(
|
||||||
|
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||||
logging.info(
|
video = wan_t2v.generate(
|
||||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
args.prompt,
|
||||||
video = wan_t2v.generate(
|
size=SIZE_CONFIGS[args.size],
|
||||||
args.prompt,
|
frame_num=args.frame_num,
|
||||||
size=SIZE_CONFIGS[args.size],
|
shift=args.sample_shift,
|
||||||
frame_num=args.frame_num,
|
sample_solver=args.sample_solver,
|
||||||
shift=args.sample_shift,
|
sampling_steps=args.sample_steps,
|
||||||
sample_solver=args.sample_solver,
|
guide_scale=args.sample_guide_scale,
|
||||||
sampling_steps=args.sample_steps,
|
seed=args.base_seed,
|
||||||
guide_scale=args.sample_guide_scale,
|
offload_model=args.offload_model)
|
||||||
seed=args.base_seed,
|
|
||||||
offload_model=args.offload_model)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if args.prompt is None:
|
|
||||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
|
||||||
if args.image is None:
|
|
||||||
args.image = EXAMPLE_PROMPT[args.task]["image"]
|
|
||||||
logging.info(f"Input prompt: {args.prompt}")
|
|
||||||
logging.info(f"Input image: {args.image}")
|
|
||||||
|
|
||||||
img = Image.open(args.image).convert("RGB")
|
|
||||||
if args.use_prompt_extend:
|
|
||||||
logging.info("Extending prompt ...")
|
|
||||||
if rank == 0:
|
|
||||||
prompt_output = prompt_expander(
|
|
||||||
args.prompt,
|
|
||||||
tar_lang=args.prompt_extend_target_lang,
|
|
||||||
image=img,
|
|
||||||
seed=args.base_seed)
|
|
||||||
if prompt_output.status == False:
|
|
||||||
logging.info(
|
|
||||||
f"Extending prompt failed: {prompt_output.message}")
|
|
||||||
logging.info("Falling back to original prompt.")
|
|
||||||
input_prompt = args.prompt
|
|
||||||
else:
|
|
||||||
input_prompt = prompt_output.prompt
|
|
||||||
input_prompt = [input_prompt]
|
|
||||||
else:
|
|
||||||
input_prompt = [None]
|
|
||||||
if dist.is_initialized():
|
|
||||||
dist.broadcast_object_list(input_prompt, src=0)
|
|
||||||
args.prompt = input_prompt[0]
|
|
||||||
logging.info(f"Extended prompt: {args.prompt}")
|
|
||||||
|
|
||||||
logging.info("Creating WanI2V pipeline.")
|
|
||||||
wan_i2v = wan.WanI2V(
|
|
||||||
config=cfg,
|
|
||||||
checkpoint_dir=args.ckpt_dir,
|
|
||||||
device_id=device,
|
|
||||||
rank=rank,
|
|
||||||
t5_fsdp=args.t5_fsdp,
|
|
||||||
dit_fsdp=args.dit_fsdp,
|
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
|
||||||
t5_cpu=args.t5_cpu,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Generating video ...")
|
|
||||||
video = wan_i2v.generate(
|
|
||||||
args.prompt,
|
|
||||||
img,
|
|
||||||
max_area=MAX_AREA_CONFIGS[args.size],
|
|
||||||
frame_num=args.frame_num,
|
|
||||||
shift=args.sample_shift,
|
|
||||||
sample_solver=args.sample_solver,
|
|
||||||
sampling_steps=args.sample_steps,
|
|
||||||
guide_scale=args.sample_guide_scale,
|
|
||||||
seed=args.base_seed,
|
|
||||||
offload_model=args.offload_model)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
if args.save_file is None:
|
|
||||||
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
|
|
||||||
"_")[:50]
|
|
||||||
suffix = '.png' if "t2i" in args.task else '.mp4'
|
|
||||||
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
|
||||||
|
|
||||||
if "t2i" in args.task:
|
|
||||||
logging.info(f"Saving generated image to {args.save_file}")
|
|
||||||
cache_image(
|
|
||||||
tensor=video.squeeze(1)[None],
|
|
||||||
save_file=args.save_file,
|
|
||||||
nrow=1,
|
|
||||||
normalize=True,
|
|
||||||
value_range=(-1, 1))
|
|
||||||
else:
|
else:
|
||||||
logging.info(f"Saving generated video to {args.save_file}")
|
if args.prompt is None:
|
||||||
cache_video(
|
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||||
tensor=video[None],
|
if args.image is None:
|
||||||
save_file=args.save_file,
|
args.image = EXAMPLE_PROMPT[args.task]["image"]
|
||||||
fps=cfg.sample_fps,
|
logging.info(f"Input prompt: {args.prompt}")
|
||||||
nrow=1,
|
logging.info(f"Input image: {args.image}")
|
||||||
normalize=True,
|
|
||||||
value_range=(-1, 1))
|
img = Image.open(args.image).convert("RGB")
|
||||||
logging.info("Finished.")
|
if args.use_prompt_extend:
|
||||||
|
logging.info("Extending prompt ...")
|
||||||
|
if rank == 0:
|
||||||
|
prompt_output = prompt_expander(
|
||||||
|
args.prompt,
|
||||||
|
tar_lang=args.prompt_extend_target_lang,
|
||||||
|
image=img,
|
||||||
|
seed=args.base_seed)
|
||||||
|
if prompt_output.status == False:
|
||||||
|
logging.info(
|
||||||
|
f"Extending prompt failed: {prompt_output.message}")
|
||||||
|
logging.info("Falling back to original prompt.")
|
||||||
|
input_prompt = args.prompt
|
||||||
|
else:
|
||||||
|
input_prompt = prompt_output.prompt
|
||||||
|
input_prompt = [input_prompt]
|
||||||
|
else:
|
||||||
|
input_prompt = [None]
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.broadcast_object_list(input_prompt, src=0)
|
||||||
|
args.prompt = input_prompt[0]
|
||||||
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
|
|
||||||
|
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||||
|
logging.info("Creating WanI2V pipeline.")
|
||||||
|
wan_i2v = wan.WanI2V(
|
||||||
|
config=cfg,
|
||||||
|
checkpoint_dir=args.ckpt_dir,
|
||||||
|
device_id=device,
|
||||||
|
rank=rank,
|
||||||
|
t5_fsdp=args.t5_fsdp,
|
||||||
|
dit_fsdp=args.dit_fsdp,
|
||||||
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
|
t5_cpu=args.t5_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Generating video...")
|
||||||
|
video = wan_i2v.generate(
|
||||||
|
args.prompt,
|
||||||
|
img,
|
||||||
|
max_area=MAX_AREA_CONFIGS[args.size],
|
||||||
|
frame_num=args.frame_num,
|
||||||
|
shift=args.sample_shift,
|
||||||
|
sample_solver=args.sample_solver,
|
||||||
|
sampling_steps=args.sample_steps,
|
||||||
|
guide_scale=args.sample_guide_scale,
|
||||||
|
seed=args.base_seed,
|
||||||
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
if args.save_file is None:
|
||||||
|
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
|
||||||
|
"_")[:50]
|
||||||
|
suffix = '.png' if "t2i" in args.task else '.mp4'
|
||||||
|
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
||||||
|
|
||||||
|
# 20250226 pftq: if multiple videos per batch, change the filename index
|
||||||
|
if args.batch_size > 1:
|
||||||
|
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
|
||||||
|
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", str(batch_index+1)+".mp4")
|
||||||
|
|
||||||
|
if "t2i" in args.task:
|
||||||
|
logging.info(f"Saving generated image to {args.save_file}")
|
||||||
|
cache_image(
|
||||||
|
tensor=video.squeeze(1)[None],
|
||||||
|
save_file=args.save_file,
|
||||||
|
nrow=1,
|
||||||
|
normalize=True,
|
||||||
|
value_range=(-1, 1))
|
||||||
|
else:
|
||||||
|
logging.info(f"Saving generated video to {args.save_file}")
|
||||||
|
cache_video(
|
||||||
|
tensor=video[None],
|
||||||
|
save_file=args.save_file,
|
||||||
|
fps=cfg.sample_fps,
|
||||||
|
nrow=1,
|
||||||
|
normalize=True,
|
||||||
|
value_range=(-1, 1))
|
||||||
|
logging.info("Finished.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
generate(args)
|
generate(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user