mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Added --batch_size parameter to generate multiple videos per run without reloading the model.
Saves a few minutes on reloading the model and having to re-enter the command each time.
This commit is contained in:
		
							parent
							
								
									190e928681
								
							
						
					
					
						commit
						b15339baa3
					
				
							
								
								
									
										289
									
								
								generate.py
									
									
									
									
									
								
							
							
						
						
									
										289
									
								
								generate.py
									
									
									
									
									
								
							@ -187,6 +187,13 @@ def _parse_args():
 | 
				
			|||||||
        default=5.0,
 | 
					        default=5.0,
 | 
				
			||||||
        help="Classifier free guidance scale.")
 | 
					        help="Classifier free guidance scale.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 20250226 pftq: batch size to do more videos without reloading the model
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--batch_size",
 | 
				
			||||||
 | 
					        type=int,
 | 
				
			||||||
 | 
					        default=1,
 | 
				
			||||||
 | 
					        help="Number of videos per run without reloading the model again.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _validate_args(args)
 | 
					    _validate_args(args)
 | 
				
			||||||
@ -270,142 +277,156 @@ def generate(args):
 | 
				
			|||||||
        dist.broadcast_object_list(base_seed, src=0)
 | 
					        dist.broadcast_object_list(base_seed, src=0)
 | 
				
			||||||
        args.base_seed = base_seed[0]
 | 
					        args.base_seed = base_seed[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if "t2v" in args.task or "t2i" in args.task:
 | 
					    # 20250226 pftq: Do more videos in one batch without reloading the model
 | 
				
			||||||
        if args.prompt is None:
 | 
					    for batch_index in range(args.batch_size):
 | 
				
			||||||
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
					        if batch_index > 0:
 | 
				
			||||||
        logging.info(f"Input prompt: {args.prompt}")
 | 
					            args.base_seed = random.randint(0, sys.maxsize)
 | 
				
			||||||
        if args.use_prompt_extend:
 | 
					        if args.batch_size > 1:
 | 
				
			||||||
            logging.info("Extending prompt ...")
 | 
					            logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
 | 
				
			||||||
            if rank == 0:
 | 
					        
 | 
				
			||||||
                prompt_output = prompt_expander(
 | 
					        if "t2v" in args.task or "t2i" in args.task:
 | 
				
			||||||
                    args.prompt,
 | 
					            if args.prompt is None:
 | 
				
			||||||
                    tar_lang=args.prompt_extend_target_lang,
 | 
					                args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
				
			||||||
                    seed=args.base_seed)
 | 
					            logging.info(f"Input prompt: {args.prompt}")
 | 
				
			||||||
                if prompt_output.status == False:
 | 
					            if args.use_prompt_extend:
 | 
				
			||||||
                    logging.info(
 | 
					                logging.info("Extending prompt ...")
 | 
				
			||||||
                        f"Extending prompt failed: {prompt_output.message}")
 | 
					                if rank == 0:
 | 
				
			||||||
                    logging.info("Falling back to original prompt.")
 | 
					                    prompt_output = prompt_expander(
 | 
				
			||||||
                    input_prompt = args.prompt
 | 
					                        args.prompt,
 | 
				
			||||||
 | 
					                        tar_lang=args.prompt_extend_target_lang,
 | 
				
			||||||
 | 
					                        seed=args.base_seed)
 | 
				
			||||||
 | 
					                    if prompt_output.status == False:
 | 
				
			||||||
 | 
					                        logging.info(
 | 
				
			||||||
 | 
					                            f"Extending prompt failed: {prompt_output.message}")
 | 
				
			||||||
 | 
					                        logging.info("Falling back to original prompt.")
 | 
				
			||||||
 | 
					                        input_prompt = args.prompt
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        input_prompt = prompt_output.prompt
 | 
				
			||||||
 | 
					                    input_prompt = [input_prompt]
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    input_prompt = prompt_output.prompt
 | 
					                    input_prompt = [None]
 | 
				
			||||||
                input_prompt = [input_prompt]
 | 
					                if dist.is_initialized():
 | 
				
			||||||
            else:
 | 
					                    dist.broadcast_object_list(input_prompt, src=0)
 | 
				
			||||||
                input_prompt = [None]
 | 
					                args.prompt = input_prompt[0]
 | 
				
			||||||
            if dist.is_initialized():
 | 
					                logging.info(f"Extended prompt: {args.prompt}")
 | 
				
			||||||
                dist.broadcast_object_list(input_prompt, src=0)
 | 
					    
 | 
				
			||||||
            args.prompt = input_prompt[0]
 | 
					            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
				
			||||||
            logging.info(f"Extended prompt: {args.prompt}")
 | 
					                logging.info("Creating WanT2V pipeline.")
 | 
				
			||||||
 | 
					                wan_t2v = wan.WanT2V(
 | 
				
			||||||
        logging.info("Creating WanT2V pipeline.")
 | 
					                    config=cfg,
 | 
				
			||||||
        wan_t2v = wan.WanT2V(
 | 
					                    checkpoint_dir=args.ckpt_dir,
 | 
				
			||||||
            config=cfg,
 | 
					                    device_id=device,
 | 
				
			||||||
            checkpoint_dir=args.ckpt_dir,
 | 
					                    rank=rank,
 | 
				
			||||||
            device_id=device,
 | 
					                    t5_fsdp=args.t5_fsdp,
 | 
				
			||||||
            rank=rank,
 | 
					                    dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
            t5_fsdp=args.t5_fsdp,
 | 
					                    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            dit_fsdp=args.dit_fsdp,
 | 
					                    t5_cpu=args.t5_cpu,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					                )
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					    
 | 
				
			||||||
        )
 | 
					            logging.info(
 | 
				
			||||||
 | 
					                f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 | 
				
			||||||
        logging.info(
 | 
					            video = wan_t2v.generate(
 | 
				
			||||||
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 | 
					                args.prompt,
 | 
				
			||||||
        video = wan_t2v.generate(
 | 
					                size=SIZE_CONFIGS[args.size],
 | 
				
			||||||
            args.prompt,
 | 
					                frame_num=args.frame_num,
 | 
				
			||||||
            size=SIZE_CONFIGS[args.size],
 | 
					                shift=args.sample_shift,
 | 
				
			||||||
            frame_num=args.frame_num,
 | 
					                sample_solver=args.sample_solver,
 | 
				
			||||||
            shift=args.sample_shift,
 | 
					                sampling_steps=args.sample_steps,
 | 
				
			||||||
            sample_solver=args.sample_solver,
 | 
					                guide_scale=args.sample_guide_scale,
 | 
				
			||||||
            sampling_steps=args.sample_steps,
 | 
					                seed=args.base_seed,
 | 
				
			||||||
            guide_scale=args.sample_guide_scale,
 | 
					                offload_model=args.offload_model)
 | 
				
			||||||
            seed=args.base_seed,
 | 
					    
 | 
				
			||||||
            offload_model=args.offload_model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if args.prompt is None:
 | 
					 | 
				
			||||||
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
					 | 
				
			||||||
        if args.image is None:
 | 
					 | 
				
			||||||
            args.image = EXAMPLE_PROMPT[args.task]["image"]
 | 
					 | 
				
			||||||
        logging.info(f"Input prompt: {args.prompt}")
 | 
					 | 
				
			||||||
        logging.info(f"Input image: {args.image}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        img = Image.open(args.image).convert("RGB")
 | 
					 | 
				
			||||||
        if args.use_prompt_extend:
 | 
					 | 
				
			||||||
            logging.info("Extending prompt ...")
 | 
					 | 
				
			||||||
            if rank == 0:
 | 
					 | 
				
			||||||
                prompt_output = prompt_expander(
 | 
					 | 
				
			||||||
                    args.prompt,
 | 
					 | 
				
			||||||
                    tar_lang=args.prompt_extend_target_lang,
 | 
					 | 
				
			||||||
                    image=img,
 | 
					 | 
				
			||||||
                    seed=args.base_seed)
 | 
					 | 
				
			||||||
                if prompt_output.status == False:
 | 
					 | 
				
			||||||
                    logging.info(
 | 
					 | 
				
			||||||
                        f"Extending prompt failed: {prompt_output.message}")
 | 
					 | 
				
			||||||
                    logging.info("Falling back to original prompt.")
 | 
					 | 
				
			||||||
                    input_prompt = args.prompt
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    input_prompt = prompt_output.prompt
 | 
					 | 
				
			||||||
                input_prompt = [input_prompt]
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                input_prompt = [None]
 | 
					 | 
				
			||||||
            if dist.is_initialized():
 | 
					 | 
				
			||||||
                dist.broadcast_object_list(input_prompt, src=0)
 | 
					 | 
				
			||||||
            args.prompt = input_prompt[0]
 | 
					 | 
				
			||||||
            logging.info(f"Extended prompt: {args.prompt}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        logging.info("Creating WanI2V pipeline.")
 | 
					 | 
				
			||||||
        wan_i2v = wan.WanI2V(
 | 
					 | 
				
			||||||
            config=cfg,
 | 
					 | 
				
			||||||
            checkpoint_dir=args.ckpt_dir,
 | 
					 | 
				
			||||||
            device_id=device,
 | 
					 | 
				
			||||||
            rank=rank,
 | 
					 | 
				
			||||||
            t5_fsdp=args.t5_fsdp,
 | 
					 | 
				
			||||||
            dit_fsdp=args.dit_fsdp,
 | 
					 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        logging.info("Generating video ...")
 | 
					 | 
				
			||||||
        video = wan_i2v.generate(
 | 
					 | 
				
			||||||
            args.prompt,
 | 
					 | 
				
			||||||
            img,
 | 
					 | 
				
			||||||
            max_area=MAX_AREA_CONFIGS[args.size],
 | 
					 | 
				
			||||||
            frame_num=args.frame_num,
 | 
					 | 
				
			||||||
            shift=args.sample_shift,
 | 
					 | 
				
			||||||
            sample_solver=args.sample_solver,
 | 
					 | 
				
			||||||
            sampling_steps=args.sample_steps,
 | 
					 | 
				
			||||||
            guide_scale=args.sample_guide_scale,
 | 
					 | 
				
			||||||
            seed=args.base_seed,
 | 
					 | 
				
			||||||
            offload_model=args.offload_model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if rank == 0:
 | 
					 | 
				
			||||||
        if args.save_file is None:
 | 
					 | 
				
			||||||
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
 | 
					 | 
				
			||||||
            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
 | 
					 | 
				
			||||||
                                                                     "_")[:50]
 | 
					 | 
				
			||||||
            suffix = '.png' if "t2i" in args.task else '.mp4'
 | 
					 | 
				
			||||||
            args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if "t2i" in args.task:
 | 
					 | 
				
			||||||
            logging.info(f"Saving generated image to {args.save_file}")
 | 
					 | 
				
			||||||
            cache_image(
 | 
					 | 
				
			||||||
                tensor=video.squeeze(1)[None],
 | 
					 | 
				
			||||||
                save_file=args.save_file,
 | 
					 | 
				
			||||||
                nrow=1,
 | 
					 | 
				
			||||||
                normalize=True,
 | 
					 | 
				
			||||||
                value_range=(-1, 1))
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            logging.info(f"Saving generated video to {args.save_file}")
 | 
					            if args.prompt is None:
 | 
				
			||||||
            cache_video(
 | 
					                args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
				
			||||||
                tensor=video[None],
 | 
					            if args.image is None:
 | 
				
			||||||
                save_file=args.save_file,
 | 
					                args.image = EXAMPLE_PROMPT[args.task]["image"]
 | 
				
			||||||
                fps=cfg.sample_fps,
 | 
					            logging.info(f"Input prompt: {args.prompt}")
 | 
				
			||||||
                nrow=1,
 | 
					            logging.info(f"Input image: {args.image}")
 | 
				
			||||||
                normalize=True,
 | 
					    
 | 
				
			||||||
                value_range=(-1, 1))
 | 
					            img = Image.open(args.image).convert("RGB")
 | 
				
			||||||
    logging.info("Finished.")
 | 
					            if args.use_prompt_extend:
 | 
				
			||||||
 | 
					                logging.info("Extending prompt ...")
 | 
				
			||||||
 | 
					                if rank == 0:
 | 
				
			||||||
 | 
					                    prompt_output = prompt_expander(
 | 
				
			||||||
 | 
					                        args.prompt,
 | 
				
			||||||
 | 
					                        tar_lang=args.prompt_extend_target_lang,
 | 
				
			||||||
 | 
					                        image=img,
 | 
				
			||||||
 | 
					                        seed=args.base_seed)
 | 
				
			||||||
 | 
					                    if prompt_output.status == False:
 | 
				
			||||||
 | 
					                        logging.info(
 | 
				
			||||||
 | 
					                            f"Extending prompt failed: {prompt_output.message}")
 | 
				
			||||||
 | 
					                        logging.info("Falling back to original prompt.")
 | 
				
			||||||
 | 
					                        input_prompt = args.prompt
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        input_prompt = prompt_output.prompt
 | 
				
			||||||
 | 
					                    input_prompt = [input_prompt]
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    input_prompt = [None]
 | 
				
			||||||
 | 
					                if dist.is_initialized():
 | 
				
			||||||
 | 
					                    dist.broadcast_object_list(input_prompt, src=0)
 | 
				
			||||||
 | 
					                args.prompt = input_prompt[0]
 | 
				
			||||||
 | 
					                logging.info(f"Extended prompt: {args.prompt}")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
				
			||||||
 | 
					                logging.info("Creating WanI2V pipeline.")
 | 
				
			||||||
 | 
					                wan_i2v = wan.WanI2V(
 | 
				
			||||||
 | 
					                    config=cfg,
 | 
				
			||||||
 | 
					                    checkpoint_dir=args.ckpt_dir,
 | 
				
			||||||
 | 
					                    device_id=device,
 | 
				
			||||||
 | 
					                    rank=rank,
 | 
				
			||||||
 | 
					                    t5_fsdp=args.t5_fsdp,
 | 
				
			||||||
 | 
					                    dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
 | 
					                    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
 | 
					                    t5_cpu=args.t5_cpu,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					            logging.info("Generating video...")
 | 
				
			||||||
 | 
					            video = wan_i2v.generate(
 | 
				
			||||||
 | 
					                args.prompt,
 | 
				
			||||||
 | 
					                img,
 | 
				
			||||||
 | 
					                max_area=MAX_AREA_CONFIGS[args.size],
 | 
				
			||||||
 | 
					                frame_num=args.frame_num,
 | 
				
			||||||
 | 
					                shift=args.sample_shift,
 | 
				
			||||||
 | 
					                sample_solver=args.sample_solver,
 | 
				
			||||||
 | 
					                sampling_steps=args.sample_steps,
 | 
				
			||||||
 | 
					                guide_scale=args.sample_guide_scale,
 | 
				
			||||||
 | 
					                seed=args.base_seed,
 | 
				
			||||||
 | 
					                offload_model=args.offload_model)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					        if rank == 0:
 | 
				
			||||||
 | 
					            if args.save_file is None:
 | 
				
			||||||
 | 
					                formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
 | 
				
			||||||
 | 
					                formatted_prompt = args.prompt.replace(" ", "_").replace("/",
 | 
				
			||||||
 | 
					                                                                         "_")[:50]
 | 
				
			||||||
 | 
					                suffix = '.png' if "t2i" in args.task else '.mp4'
 | 
				
			||||||
 | 
					                args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					            # 20250226 pftq: if multiple videos per batch, change the filename index
 | 
				
			||||||
 | 
					            if args.batch_size > 1:
 | 
				
			||||||
 | 
					                if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
 | 
				
			||||||
 | 
					                else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", str(batch_index+1)+".mp4")
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            if "t2i" in args.task:
 | 
				
			||||||
 | 
					                logging.info(f"Saving generated image to {args.save_file}")
 | 
				
			||||||
 | 
					                cache_image(
 | 
				
			||||||
 | 
					                    tensor=video.squeeze(1)[None],
 | 
				
			||||||
 | 
					                    save_file=args.save_file,
 | 
				
			||||||
 | 
					                    nrow=1,
 | 
				
			||||||
 | 
					                    normalize=True,
 | 
				
			||||||
 | 
					                    value_range=(-1, 1))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                logging.info(f"Saving generated video to {args.save_file}")
 | 
				
			||||||
 | 
					                cache_video(
 | 
				
			||||||
 | 
					                    tensor=video[None],
 | 
				
			||||||
 | 
					                    save_file=args.save_file,
 | 
				
			||||||
 | 
					                    fps=cfg.sample_fps,
 | 
				
			||||||
 | 
					                    nrow=1,
 | 
				
			||||||
 | 
					                    normalize=True,
 | 
				
			||||||
 | 
					                    value_range=(-1, 1))
 | 
				
			||||||
 | 
					        logging.info("Finished.")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    args = _parse_args()
 | 
					    args = _parse_args()
 | 
				
			||||||
    generate(args)
 | 
					    generate(args)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user