This commit is contained in:
pftq 2025-02-28 09:46:33 +08:00 committed by GitHub
commit ad7fd70722
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 202 additions and 136 deletions

View File

@ -187,6 +187,20 @@ def _parse_args():
default=5.0, default=5.0,
help="Classifier free guidance scale.") help="Classifier free guidance scale.")
# 20250226 pftq: batch size to do more videos without reloading the model
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Number of videos per run without reloading the model again.")
# 20250227 pftq: variety batch with different CFG/steps each video, labeled in filename.
parser.add_argument(
"--variety_batch",
action="store_true",
default=False,
help="Different CFG/steps each video, labeled in filename.")
args = parser.parse_args() args = parser.parse_args()
_validate_args(args) _validate_args(args)
@ -270,11 +284,34 @@ def generate(args):
dist.broadcast_object_list(base_seed, src=0) dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0] args.base_seed = base_seed[0]
# 20250226 pftq: Do more videos in one batch without reloading the model
wan_t2v = None
wan_i2v = None
# 20250226 pftq: Pre-generate and sync seeds for all videos
seeds = []
if dist.is_initialized():
if rank == 0:
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
else:
seeds = [None] * args.batch_size
dist.broadcast_object_list(seeds, src=0)
else:
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
for batch_index in range(args.batch_size):
if batch_index > 0:
args.base_seed = seeds[batch_index]
if args.batch_size > 1:
logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
if "t2v" in args.task or "t2i" in args.task: if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}") logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend: if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
logging.info("Extending prompt ...") logging.info("Extending prompt ...")
if rank == 0: if rank == 0:
prompt_output = prompt_expander( prompt_output = prompt_expander(
@ -296,6 +333,7 @@ def generate(args):
args.prompt = input_prompt[0] args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
if batch_index == 0: # 20250226 pftq: only load the model once
logging.info("Creating WanT2V pipeline.") logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V( wan_t2v = wan.WanT2V(
config=cfg, config=cfg,
@ -329,8 +367,9 @@ def generate(args):
logging.info(f"Input prompt: {args.prompt}") logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}") logging.info(f"Input image: {args.image}")
if batch_index == 0: # 20250226 pftq: load once
img = Image.open(args.image).convert("RGB") img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend: if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
logging.info("Extending prompt ...") logging.info("Extending prompt ...")
if rank == 0: if rank == 0:
prompt_output = prompt_expander( prompt_output = prompt_expander(
@ -353,6 +392,7 @@ def generate(args):
args.prompt = input_prompt[0] args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}") logging.info(f"Extended prompt: {args.prompt}")
if batch_index == 0: # 20250226 pftq: only load the model once
logging.info("Creating WanI2V pipeline.") logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V( wan_i2v = wan.WanI2V(
config=cfg, config=cfg,
@ -365,7 +405,21 @@ def generate(args):
t5_cpu=args.t5_cpu, t5_cpu=args.t5_cpu,
) )
# 20250227 pftq: variety batch
if args.variety_batch and batch_index > 0:
args.sample_guide_scale = args.sample_guide_scale + 0.5
if args.sample_guide_scale > 10:
args.sample_guide_scale = 3
args.sample_steps = args.sample_steps + 10
if args.sample_steps > 150:
args.sample_stemps = 50
args.save_file = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_wan-server_CFG-"+str(args.sample_guide_scale)+"_steps-"+str(args.sample_steps)+"_"+str(args.ulysses_size)+"-GPUs-"+str(args.frame_num)+"f.mp4"
logging.info("Generating video... "+args.save_file)
else:
logging.info("Generating video...") logging.info("Generating video...")
video = wan_i2v.generate( video = wan_i2v.generate(
args.prompt, args.prompt,
img, img,
@ -386,6 +440,11 @@ def generate(args):
suffix = '.png' if "t2i" in args.task else '.mp4' suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
# 20250226 pftq: if multiple videos per batch, change the filename index
if args.batch_size > 1:
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4")
if "t2i" in args.task: if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}") logging.info(f"Saving generated image to {args.save_file}")
cache_image( cache_image(

View File

@ -204,13 +204,19 @@ class WanI2V:
generator=seed_g, generator=seed_g,
device=self.device) device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) #20250226 fix noise assuming frames hardcoded at 81 (21 latent frames)
latent_frame_num = (F - 1) // self.vae_stride[0] + 1
noise = torch.randn(16, latent_frame_num, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)
#msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device) # 20250226 pftq: Fixed frames being hardcoded as 81
msk[:, 1:] = 0 msk[:, 1:] = 0
msk = torch.concat([ msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
], ],
dim=1) dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) #msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.view(1, latent_frame_num, 4, lat_h, lat_w) # 20250226 pftq: align to actual frames, not hardcoded 81 frames
msk = msk.transpose(1, 2)[0] msk = msk.transpose(1, 2)[0]
if n_prompt == "": if n_prompt == "":
@ -239,7 +245,8 @@ class WanI2V:
torch.nn.functional.interpolate( torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose( img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1), 0, 1),
torch.zeros(3, 80, h, w) #torch.zeros(3, 80, h, w)
torch.zeros(3, F-1, h, w) # 20250226 pftq: fixed 80 being hardcoded frame-1
], ],
dim=1).to(self.device) dim=1).to(self.device)
])[0] ])[0]