diff --git a/generate.py b/generate.py index f27bb98..2ac4f8f 100644 --- a/generate.py +++ b/generate.py @@ -187,6 +187,13 @@ def _parse_args(): default=5.0, help="Classifier free guidance scale.") + # 20250226 pftq: batch size to do more videos without reloading the model + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Number of videos per run without reloading the model again.") + args = parser.parse_args() _validate_args(args) @@ -270,142 +277,173 @@ def generate(args): dist.broadcast_object_list(base_seed, src=0) args.base_seed = base_seed[0] - if "t2v" in args.task or "t2i" in args.task: - if args.prompt is None: - args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] - logging.info(f"Input prompt: {args.prompt}") - if args.use_prompt_extend: - logging.info("Extending prompt ...") - if rank == 0: - prompt_output = prompt_expander( - args.prompt, - tar_lang=args.prompt_extend_target_lang, - seed=args.base_seed) - if prompt_output.status == False: - logging.info( - f"Extending prompt failed: {prompt_output.message}") - logging.info("Falling back to original prompt.") - input_prompt = args.prompt - else: - input_prompt = prompt_output.prompt - input_prompt = [input_prompt] - else: - input_prompt = [None] - if dist.is_initialized(): - dist.broadcast_object_list(input_prompt, src=0) - args.prompt = input_prompt[0] - logging.info(f"Extended prompt: {args.prompt}") - - logging.info("Creating WanT2V pipeline.") - wan_t2v = wan.WanT2V( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) - - logging.info( - f"Generating {'image' if 't2i' in args.task else 'video'} ...") - video = wan_t2v.generate( - args.prompt, - size=SIZE_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) - - else: - if args.prompt is None: - args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] - if args.image is None: - args.image = EXAMPLE_PROMPT[args.task]["image"] - logging.info(f"Input prompt: {args.prompt}") - logging.info(f"Input image: {args.image}") - - img = Image.open(args.image).convert("RGB") - if args.use_prompt_extend: - logging.info("Extending prompt ...") - if rank == 0: - prompt_output = prompt_expander( - args.prompt, - tar_lang=args.prompt_extend_target_lang, - image=img, - seed=args.base_seed) - if prompt_output.status == False: - logging.info( - f"Extending prompt failed: {prompt_output.message}") - logging.info("Falling back to original prompt.") - input_prompt = args.prompt - else: - input_prompt = prompt_output.prompt - input_prompt = [input_prompt] - else: - input_prompt = [None] - if dist.is_initialized(): - dist.broadcast_object_list(input_prompt, src=0) - args.prompt = input_prompt[0] - logging.info(f"Extended prompt: {args.prompt}") - - logging.info("Creating WanI2V pipeline.") - wan_i2v = wan.WanI2V( - config=cfg, - checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, - ) - - logging.info("Generating video ...") - video = wan_i2v.generate( - args.prompt, - img, - max_area=MAX_AREA_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) - - if rank == 0: - if args.save_file is None: - formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") - formatted_prompt = args.prompt.replace(" ", "_").replace("/", - "_")[:50] - suffix = '.png' if "t2i" in args.task else '.mp4' - args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix - - if "t2i" in args.task: - logging.info(f"Saving generated image to {args.save_file}") - cache_image( - tensor=video.squeeze(1)[None], - save_file=args.save_file, - nrow=1, - normalize=True, - value_range=(-1, 1)) + # 20250226 pftq: Do more videos in one batch without reloading the model + wan_t2v = None + wan_i2v = None + + # 20250226 pftq: Pre-generate and sync seeds for all videos + seeds = [] + if dist.is_initialized(): + if rank == 0: + seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)] + seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1)) else: - logging.info(f"Saving generated video to {args.save_file}") - cache_video( - tensor=video[None], - save_file=args.save_file, - fps=cfg.sample_fps, - nrow=1, - normalize=True, - value_range=(-1, 1)) - logging.info("Finished.") + seeds = [None] * args.batch_size + dist.broadcast_object_list(seeds, src=0) + else: + seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)] + seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1)) + + for batch_index in range(args.batch_size): + if batch_index > 0: + args.base_seed = seeds[batch_index] + if args.batch_size > 1: + logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1)) + + if "t2v" in args.task or "t2i" in args.task: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + logging.info(f"Input prompt: {args.prompt}") + if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + if batch_index == 0: # 20250226 pftq: only load the model once + logging.info("Creating WanT2V pipeline.") + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info( + f"Generating {'image' if 't2i' in args.task else 'video'} ...") + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + else: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None: + args.image = EXAMPLE_PROMPT[args.task]["image"] + logging.info(f"Input prompt: {args.prompt}") + logging.info(f"Input image: {args.image}") + if batch_index == 0: # 20250226 pftq: load once + img = Image.open(args.image).convert("RGB") + if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + image=img, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + if batch_index == 0: # 20250226 pftq: only load the model once + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info("Generating video...") + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", + "_")[:50] + suffix = '.png' if "t2i" in args.task else '.mp4' + args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix + + # 20250226 pftq: if multiple videos per batch, change the filename index + if args.batch_size > 1: + if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4") + else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4") + + if "t2i" in args.task: + logging.info(f"Saving generated image to {args.save_file}") + cache_image( + tensor=video.squeeze(1)[None], + save_file=args.save_file, + nrow=1, + normalize=True, + value_range=(-1, 1)) + else: + logging.info(f"Saving generated video to {args.save_file}") + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + if __name__ == "__main__": args = _parse_args() generate(args)