mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Added --batch_size parameter to generate multiple videos per run without reloading the model.
Saves a few minutes on reloading the model and having to re-enter the command each time.
This commit is contained in:
		
							parent
							
								
									190e928681
								
							
						
					
					
						commit
						b15339baa3
					
				
							
								
								
									
										23
									
								
								generate.py
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								generate.py
									
									
									
									
									
								
							@ -187,6 +187,13 @@ def _parse_args():
 | 
			
		||||
        default=5.0,
 | 
			
		||||
        help="Classifier free guidance scale.")
 | 
			
		||||
 | 
			
		||||
    # 20250226 pftq: batch size to do more videos without reloading the model
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--batch_size",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=1,
 | 
			
		||||
        help="Number of videos per run without reloading the model again.")
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    _validate_args(args)
 | 
			
		||||
@ -270,6 +277,13 @@ def generate(args):
 | 
			
		||||
        dist.broadcast_object_list(base_seed, src=0)
 | 
			
		||||
        args.base_seed = base_seed[0]
 | 
			
		||||
 | 
			
		||||
    # 20250226 pftq: Do more videos in one batch without reloading the model
 | 
			
		||||
    for batch_index in range(args.batch_size):
 | 
			
		||||
        if batch_index > 0:
 | 
			
		||||
            args.base_seed = random.randint(0, sys.maxsize)
 | 
			
		||||
        if args.batch_size > 1:
 | 
			
		||||
            logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
 | 
			
		||||
        
 | 
			
		||||
        if "t2v" in args.task or "t2i" in args.task:
 | 
			
		||||
            if args.prompt is None:
 | 
			
		||||
                args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
 | 
			
		||||
@ -296,6 +310,7 @@ def generate(args):
 | 
			
		||||
                args.prompt = input_prompt[0]
 | 
			
		||||
                logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
    
 | 
			
		||||
            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
			
		||||
                logging.info("Creating WanT2V pipeline.")
 | 
			
		||||
                wan_t2v = wan.WanT2V(
 | 
			
		||||
                    config=cfg,
 | 
			
		||||
@ -353,6 +368,7 @@ def generate(args):
 | 
			
		||||
                args.prompt = input_prompt[0]
 | 
			
		||||
                logging.info(f"Extended prompt: {args.prompt}")
 | 
			
		||||
    
 | 
			
		||||
            if batch_index == 0: # 20250226 pftq: only load the model once
 | 
			
		||||
                logging.info("Creating WanI2V pipeline.")
 | 
			
		||||
                wan_i2v = wan.WanI2V(
 | 
			
		||||
                    config=cfg,
 | 
			
		||||
@ -365,7 +381,7 @@ def generate(args):
 | 
			
		||||
                    t5_cpu=args.t5_cpu,
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
        logging.info("Generating video ...")
 | 
			
		||||
            logging.info("Generating video...")
 | 
			
		||||
            video = wan_i2v.generate(
 | 
			
		||||
                args.prompt,
 | 
			
		||||
                img,
 | 
			
		||||
@ -386,6 +402,11 @@ def generate(args):
 | 
			
		||||
                suffix = '.png' if "t2i" in args.task else '.mp4'
 | 
			
		||||
                args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 | 
			
		||||
    
 | 
			
		||||
            # 20250226 pftq: if multiple videos per batch, change the filename index
 | 
			
		||||
            if args.batch_size > 1:
 | 
			
		||||
                if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
 | 
			
		||||
                else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", str(batch_index+1)+".mp4")
 | 
			
		||||
            
 | 
			
		||||
            if "t2i" in args.task:
 | 
			
		||||
                logging.info(f"Saving generated image to {args.save_file}")
 | 
			
		||||
                cache_image(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user