mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Added --batch_size parameter to generate multiple videos per run without reloading the model.
Fixed seed synchronization handling by pre-generating them before the batch.
This commit is contained in:
		
							parent
							
								
									d7b01803de
								
							
						
					
					
						commit
						3d8f5f1a04
					
				
							
								
								
									
										304
									
								
								generate.py
									
									
									
									
									
								
							
							
						
						
									
										304
									
								
								generate.py
									
									
									
									
									
								
							@ -187,6 +187,13 @@ def _parse_args():
 | 
			
		||||
        default=5.0,
 | 
			
		||||
        help="Classifier free guidance scale.")
 | 
			
		||||
 | 
			
		||||
    # 20250226 pftq: batch size to do more videos without reloading the model
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--batch_size",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=1,
 | 
			
		||||
        help="Number of videos per run without reloading the model again.")
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    _validate_args(args)
 | 
			
		||||
@ -270,142 +277,173 @@ def generate(args):
 | 
			
		||||
        dist.broadcast_object_list(base_seed, src=0)
 | 
			
		||||
        args.base_seed = base_seed[0]
 | 
			
		||||
 | 
			
		||||
    if "t2v" in args.task or "t2i" in args.task:
 | 
			
		||||
        if args.prompt is None:
 | 
			
		||||
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
			
		||||
        logging.info(f"Input prompt: {args.prompt}")
 | 
			
		||||
        if args.use_prompt_extend:
 | 
			
		||||
            logging.info("Extending prompt ...")
 | 
			
		||||
            if rank == 0:
 | 
			
		||||
                prompt_output = prompt_expander(
 | 
			
		||||
                    args.prompt,
 | 
			
		||||
                    tar_lang=args.prompt_extend_target_lang,
 | 
			
		||||
                    seed=args.base_seed)
 | 
			
		||||
                if prompt_output.status == False:
 | 
			
		||||
                    logging.info(
 | 
			
		||||
                        f"Extending prompt failed: {prompt_output.message}")
 | 
			
		||||
                    logging.info("Falling back to original prompt.")
 | 
			
		||||
                    input_prompt = args.prompt
 | 
			
		||||
                else:
 | 
			
		||||
                    input_prompt = prompt_output.prompt
 | 
			
		||||
                input_prompt = [input_prompt]
 | 
			
		||||
            else:
 | 
			
		||||
                input_prompt = [None]
 | 
			
		||||
            if dist.is_initialized():
 | 
			
		||||
                dist.broadcast_object_list(input_prompt, src=0)
 | 
			
		||||
            args.prompt = input_prompt[0]
 | 
			
		||||
            logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
 | 
			
		||||
        logging.info("Creating WanT2V pipeline.")
 | 
			
		||||
        wan_t2v = wan.WanT2V(
 | 
			
		||||
            config=cfg,
 | 
			
		||||
            checkpoint_dir=args.ckpt_dir,
 | 
			
		||||
            device_id=device,
 | 
			
		||||
            rank=rank,
 | 
			
		||||
            t5_fsdp=args.t5_fsdp,
 | 
			
		||||
            dit_fsdp=args.dit_fsdp,
 | 
			
		||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
            t5_cpu=args.t5_cpu,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        logging.info(
 | 
			
		||||
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 | 
			
		||||
        video = wan_t2v.generate(
 | 
			
		||||
            args.prompt,
 | 
			
		||||
            size=SIZE_CONFIGS[args.size],
 | 
			
		||||
            frame_num=args.frame_num,
 | 
			
		||||
            shift=args.sample_shift,
 | 
			
		||||
            sample_solver=args.sample_solver,
 | 
			
		||||
            sampling_steps=args.sample_steps,
 | 
			
		||||
            guide_scale=args.sample_guide_scale,
 | 
			
		||||
            seed=args.base_seed,
 | 
			
		||||
            offload_model=args.offload_model)
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        if args.prompt is None:
 | 
			
		||||
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
			
		||||
        if args.image is None:
 | 
			
		||||
            args.image = EXAMPLE_PROMPT[args.task]["image"]
 | 
			
		||||
        logging.info(f"Input prompt: {args.prompt}")
 | 
			
		||||
        logging.info(f"Input image: {args.image}")
 | 
			
		||||
 | 
			
		||||
        img = Image.open(args.image).convert("RGB")
 | 
			
		||||
        if args.use_prompt_extend:
 | 
			
		||||
            logging.info("Extending prompt ...")
 | 
			
		||||
            if rank == 0:
 | 
			
		||||
                prompt_output = prompt_expander(
 | 
			
		||||
                    args.prompt,
 | 
			
		||||
                    tar_lang=args.prompt_extend_target_lang,
 | 
			
		||||
                    image=img,
 | 
			
		||||
                    seed=args.base_seed)
 | 
			
		||||
                if prompt_output.status == False:
 | 
			
		||||
                    logging.info(
 | 
			
		||||
                        f"Extending prompt failed: {prompt_output.message}")
 | 
			
		||||
                    logging.info("Falling back to original prompt.")
 | 
			
		||||
                    input_prompt = args.prompt
 | 
			
		||||
                else:
 | 
			
		||||
                    input_prompt = prompt_output.prompt
 | 
			
		||||
                input_prompt = [input_prompt]
 | 
			
		||||
            else:
 | 
			
		||||
                input_prompt = [None]
 | 
			
		||||
            if dist.is_initialized():
 | 
			
		||||
                dist.broadcast_object_list(input_prompt, src=0)
 | 
			
		||||
            args.prompt = input_prompt[0]
 | 
			
		||||
            logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
 | 
			
		||||
        logging.info("Creating WanI2V pipeline.")
 | 
			
		||||
        wan_i2v = wan.WanI2V(
 | 
			
		||||
            config=cfg,
 | 
			
		||||
            checkpoint_dir=args.ckpt_dir,
 | 
			
		||||
            device_id=device,
 | 
			
		||||
            rank=rank,
 | 
			
		||||
            t5_fsdp=args.t5_fsdp,
 | 
			
		||||
            dit_fsdp=args.dit_fsdp,
 | 
			
		||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
            t5_cpu=args.t5_cpu,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        logging.info("Generating video ...")
 | 
			
		||||
        video = wan_i2v.generate(
 | 
			
		||||
            args.prompt,
 | 
			
		||||
            img,
 | 
			
		||||
            max_area=MAX_AREA_CONFIGS[args.size],
 | 
			
		||||
            frame_num=args.frame_num,
 | 
			
		||||
            shift=args.sample_shift,
 | 
			
		||||
            sample_solver=args.sample_solver,
 | 
			
		||||
            sampling_steps=args.sample_steps,
 | 
			
		||||
            guide_scale=args.sample_guide_scale,
 | 
			
		||||
            seed=args.base_seed,
 | 
			
		||||
            offload_model=args.offload_model)
 | 
			
		||||
 | 
			
		||||
    if rank == 0:
 | 
			
		||||
        if args.save_file is None:
 | 
			
		||||
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
 | 
			
		||||
            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
 | 
			
		||||
                                                                     "_")[:50]
 | 
			
		||||
            suffix = '.png' if "t2i" in args.task else '.mp4'
 | 
			
		||||
            args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 | 
			
		||||
 | 
			
		||||
        if "t2i" in args.task:
 | 
			
		||||
            logging.info(f"Saving generated image to {args.save_file}")
 | 
			
		||||
            cache_image(
 | 
			
		||||
                tensor=video.squeeze(1)[None],
 | 
			
		||||
                save_file=args.save_file,
 | 
			
		||||
                nrow=1,
 | 
			
		||||
                normalize=True,
 | 
			
		||||
                value_range=(-1, 1))
 | 
			
		||||
    # 20250226 pftq: Do more videos in one batch without reloading the model
 | 
			
		||||
    wan_t2v = None
 | 
			
		||||
    wan_i2v = None
 | 
			
		||||
    
 | 
			
		||||
    # 20250226 pftq: Pre-generate and sync seeds for all videos
 | 
			
		||||
    seeds = []
 | 
			
		||||
    if dist.is_initialized():
 | 
			
		||||
        if rank == 0:
 | 
			
		||||
            seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
 | 
			
		||||
            seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
 | 
			
		||||
        else:
 | 
			
		||||
            logging.info(f"Saving generated video to {args.save_file}")
 | 
			
		||||
            cache_video(
 | 
			
		||||
                tensor=video[None],
 | 
			
		||||
                save_file=args.save_file,
 | 
			
		||||
                fps=cfg.sample_fps,
 | 
			
		||||
                nrow=1,
 | 
			
		||||
                normalize=True,
 | 
			
		||||
                value_range=(-1, 1))
 | 
			
		||||
    logging.info("Finished.")
 | 
			
		||||
            seeds = [None] * args.batch_size
 | 
			
		||||
        dist.broadcast_object_list(seeds, src=0)
 | 
			
		||||
    else:
 | 
			
		||||
        seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
 | 
			
		||||
        seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
 | 
			
		||||
        
 | 
			
		||||
    for batch_index in range(args.batch_size):
 | 
			
		||||
        if batch_index > 0:
 | 
			
		||||
            args.base_seed = seeds[batch_index] 
 | 
			
		||||
        if args.batch_size > 1:
 | 
			
		||||
            logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
 | 
			
		||||
        
 | 
			
		||||
        if "t2v" in args.task or "t2i" in args.task:
 | 
			
		||||
            if args.prompt is None:
 | 
			
		||||
                args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
			
		||||
            logging.info(f"Input prompt: {args.prompt}")
 | 
			
		||||
            if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
 | 
			
		||||
                logging.info("Extending prompt ...")
 | 
			
		||||
                if rank == 0:
 | 
			
		||||
                    prompt_output = prompt_expander(
 | 
			
		||||
                        args.prompt,
 | 
			
		||||
                        tar_lang=args.prompt_extend_target_lang,
 | 
			
		||||
                        seed=args.base_seed)
 | 
			
		||||
                    if prompt_output.status == False:
 | 
			
		||||
                        logging.info(
 | 
			
		||||
                            f"Extending prompt failed: {prompt_output.message}")
 | 
			
		||||
                        logging.info("Falling back to original prompt.")
 | 
			
		||||
                        input_prompt = args.prompt
 | 
			
		||||
                    else:
 | 
			
		||||
                        input_prompt = prompt_output.prompt
 | 
			
		||||
                    input_prompt = [input_prompt]
 | 
			
		||||
                else:
 | 
			
		||||
                    input_prompt = [None]
 | 
			
		||||
                if dist.is_initialized():
 | 
			
		||||
                    dist.broadcast_object_list(input_prompt, src=0)
 | 
			
		||||
                args.prompt = input_prompt[0]
 | 
			
		||||
                logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
 | 
			
		||||
            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
			
		||||
                logging.info("Creating WanT2V pipeline.")
 | 
			
		||||
                wan_t2v = wan.WanT2V(
 | 
			
		||||
                    config=cfg,
 | 
			
		||||
                    checkpoint_dir=args.ckpt_dir,
 | 
			
		||||
                    device_id=device,
 | 
			
		||||
                    rank=rank,
 | 
			
		||||
                    t5_fsdp=args.t5_fsdp,
 | 
			
		||||
                    dit_fsdp=args.dit_fsdp,
 | 
			
		||||
                    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
                    t5_cpu=args.t5_cpu,
 | 
			
		||||
                )
 | 
			
		||||
    
 | 
			
		||||
            logging.info(
 | 
			
		||||
                f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 | 
			
		||||
            video = wan_t2v.generate(
 | 
			
		||||
                args.prompt,
 | 
			
		||||
                size=SIZE_CONFIGS[args.size],
 | 
			
		||||
                frame_num=args.frame_num,
 | 
			
		||||
                shift=args.sample_shift,
 | 
			
		||||
                sample_solver=args.sample_solver,
 | 
			
		||||
                sampling_steps=args.sample_steps,
 | 
			
		||||
                guide_scale=args.sample_guide_scale,
 | 
			
		||||
                seed=args.base_seed,
 | 
			
		||||
                offload_model=args.offload_model)
 | 
			
		||||
    
 | 
			
		||||
        else:
 | 
			
		||||
            if args.prompt is None:
 | 
			
		||||
                args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
			
		||||
            if args.image is None:
 | 
			
		||||
                args.image = EXAMPLE_PROMPT[args.task]["image"]
 | 
			
		||||
            logging.info(f"Input prompt: {args.prompt}")
 | 
			
		||||
            logging.info(f"Input image: {args.image}")
 | 
			
		||||
 | 
			
		||||
            if batch_index == 0: # 20250226 pftq: load once
 | 
			
		||||
                img = Image.open(args.image).convert("RGB")
 | 
			
		||||
            if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
 | 
			
		||||
                logging.info("Extending prompt ...")
 | 
			
		||||
                if rank == 0:
 | 
			
		||||
                    prompt_output = prompt_expander(
 | 
			
		||||
                        args.prompt,
 | 
			
		||||
                        tar_lang=args.prompt_extend_target_lang,
 | 
			
		||||
                        image=img,
 | 
			
		||||
                        seed=args.base_seed)
 | 
			
		||||
                    if prompt_output.status == False:
 | 
			
		||||
                        logging.info(
 | 
			
		||||
                            f"Extending prompt failed: {prompt_output.message}")
 | 
			
		||||
                        logging.info("Falling back to original prompt.")
 | 
			
		||||
                        input_prompt = args.prompt
 | 
			
		||||
                    else:
 | 
			
		||||
                        input_prompt = prompt_output.prompt
 | 
			
		||||
                    input_prompt = [input_prompt]
 | 
			
		||||
                else:
 | 
			
		||||
                    input_prompt = [None]
 | 
			
		||||
                if dist.is_initialized():
 | 
			
		||||
                    dist.broadcast_object_list(input_prompt, src=0)
 | 
			
		||||
                args.prompt = input_prompt[0]
 | 
			
		||||
                logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
 | 
			
		||||
            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
			
		||||
                logging.info("Creating WanI2V pipeline.")
 | 
			
		||||
                wan_i2v = wan.WanI2V(
 | 
			
		||||
                    config=cfg,
 | 
			
		||||
                    checkpoint_dir=args.ckpt_dir,
 | 
			
		||||
                    device_id=device,
 | 
			
		||||
                    rank=rank,
 | 
			
		||||
                    t5_fsdp=args.t5_fsdp,
 | 
			
		||||
                    dit_fsdp=args.dit_fsdp,
 | 
			
		||||
                    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
                    t5_cpu=args.t5_cpu,
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
            logging.info("Generating video...")
 | 
			
		||||
            video = wan_i2v.generate(
 | 
			
		||||
                args.prompt,
 | 
			
		||||
                img,
 | 
			
		||||
                max_area=MAX_AREA_CONFIGS[args.size],
 | 
			
		||||
                frame_num=args.frame_num,
 | 
			
		||||
                shift=args.sample_shift,
 | 
			
		||||
                sample_solver=args.sample_solver,
 | 
			
		||||
                sampling_steps=args.sample_steps,
 | 
			
		||||
                guide_scale=args.sample_guide_scale,
 | 
			
		||||
                seed=args.base_seed,
 | 
			
		||||
                offload_model=args.offload_model)
 | 
			
		||||
    
 | 
			
		||||
        if rank == 0:
 | 
			
		||||
            if args.save_file is None:
 | 
			
		||||
                formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
 | 
			
		||||
                formatted_prompt = args.prompt.replace(" ", "_").replace("/",
 | 
			
		||||
                                                                         "_")[:50]
 | 
			
		||||
                suffix = '.png' if "t2i" in args.task else '.mp4'
 | 
			
		||||
                args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 | 
			
		||||
    
 | 
			
		||||
            # 20250226 pftq: if multiple videos per batch, change the filename index
 | 
			
		||||
            if args.batch_size > 1:
 | 
			
		||||
                if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
 | 
			
		||||
                else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4")
 | 
			
		||||
            
 | 
			
		||||
            if "t2i" in args.task:
 | 
			
		||||
                logging.info(f"Saving generated image to {args.save_file}")
 | 
			
		||||
                cache_image(
 | 
			
		||||
                    tensor=video.squeeze(1)[None],
 | 
			
		||||
                    save_file=args.save_file,
 | 
			
		||||
                    nrow=1,
 | 
			
		||||
                    normalize=True,
 | 
			
		||||
                    value_range=(-1, 1))
 | 
			
		||||
            else:
 | 
			
		||||
                logging.info(f"Saving generated video to {args.save_file}")
 | 
			
		||||
                cache_video(
 | 
			
		||||
                    tensor=video[None],
 | 
			
		||||
                    save_file=args.save_file,
 | 
			
		||||
                    fps=cfg.sample_fps,
 | 
			
		||||
                    nrow=1,
 | 
			
		||||
                    normalize=True,
 | 
			
		||||
                    value_range=(-1, 1))
 | 
			
		||||
        logging.info("Finished.")
 | 
			
		||||
    
 | 
			
		||||
    
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    args = _parse_args()
 | 
			
		||||
    generate(args)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user