mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-06 23:34:53 +00:00
Merge 236dd43edd
into a326079926
This commit is contained in:
commit
ad7fd70722
325
generate.py
325
generate.py
@ -187,6 +187,20 @@ def _parse_args():
|
||||
default=5.0,
|
||||
help="Classifier free guidance scale.")
|
||||
|
||||
# 20250226 pftq: batch size to do more videos without reloading the model
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of videos per run without reloading the model again.")
|
||||
|
||||
# 20250227 pftq: variety batch with different CFG/steps each video, labeled in filename.
|
||||
parser.add_argument(
|
||||
"--variety_batch",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Different CFG/steps each video, labeled in filename.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
_validate_args(args)
|
||||
@ -270,142 +284,187 @@ def generate(args):
|
||||
dist.broadcast_object_list(base_seed, src=0)
|
||||
args.base_seed = base_seed[0]
|
||||
|
||||
if "t2v" in args.task or "t2i" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
if args.use_prompt_extend:
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
args.prompt,
|
||||
tar_lang=args.prompt_extend_target_lang,
|
||||
seed=args.base_seed)
|
||||
if prompt_output.status == False:
|
||||
logging.info(
|
||||
f"Extending prompt failed: {prompt_output.message}")
|
||||
logging.info("Falling back to original prompt.")
|
||||
input_prompt = args.prompt
|
||||
else:
|
||||
input_prompt = prompt_output.prompt
|
||||
input_prompt = [input_prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
logging.info("Creating WanT2V pipeline.")
|
||||
wan_t2v = wan.WanT2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||
video = wan_t2v.generate(
|
||||
args.prompt,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
|
||||
else:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
if args.image is None:
|
||||
args.image = EXAMPLE_PROMPT[args.task]["image"]
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
logging.info(f"Input image: {args.image}")
|
||||
|
||||
img = Image.open(args.image).convert("RGB")
|
||||
if args.use_prompt_extend:
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
args.prompt,
|
||||
tar_lang=args.prompt_extend_target_lang,
|
||||
image=img,
|
||||
seed=args.base_seed)
|
||||
if prompt_output.status == False:
|
||||
logging.info(
|
||||
f"Extending prompt failed: {prompt_output.message}")
|
||||
logging.info("Falling back to original prompt.")
|
||||
input_prompt = args.prompt
|
||||
else:
|
||||
input_prompt = prompt_output.prompt
|
||||
input_prompt = [input_prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
logging.info("Creating WanI2V pipeline.")
|
||||
wan_i2v = wan.WanI2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
logging.info("Generating video ...")
|
||||
video = wan_i2v.generate(
|
||||
args.prompt,
|
||||
img,
|
||||
max_area=MAX_AREA_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
|
||||
"_")[:50]
|
||||
suffix = '.png' if "t2i" in args.task else '.mp4'
|
||||
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
||||
|
||||
if "t2i" in args.task:
|
||||
logging.info(f"Saving generated image to {args.save_file}")
|
||||
cache_image(
|
||||
tensor=video.squeeze(1)[None],
|
||||
save_file=args.save_file,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
# 20250226 pftq: Do more videos in one batch without reloading the model
|
||||
wan_t2v = None
|
||||
wan_i2v = None
|
||||
|
||||
# 20250226 pftq: Pre-generate and sync seeds for all videos
|
||||
seeds = []
|
||||
if dist.is_initialized():
|
||||
if rank == 0:
|
||||
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
|
||||
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
|
||||
else:
|
||||
logging.info(f"Saving generated video to {args.save_file}")
|
||||
cache_video(
|
||||
tensor=video[None],
|
||||
save_file=args.save_file,
|
||||
fps=cfg.sample_fps,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
logging.info("Finished.")
|
||||
seeds = [None] * args.batch_size
|
||||
dist.broadcast_object_list(seeds, src=0)
|
||||
else:
|
||||
seeds = [args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)]
|
||||
seeds.extend(random.randint(0, sys.maxsize) for _ in range(args.batch_size - 1))
|
||||
|
||||
for batch_index in range(args.batch_size):
|
||||
if batch_index > 0:
|
||||
args.base_seed = seeds[batch_index]
|
||||
if args.batch_size > 1:
|
||||
logging.info("Batch size: "+str(args.batch_size)+" | Video #"+str(batch_index+1))
|
||||
|
||||
if "t2v" in args.task or "t2i" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
args.prompt,
|
||||
tar_lang=args.prompt_extend_target_lang,
|
||||
seed=args.base_seed)
|
||||
if prompt_output.status == False:
|
||||
logging.info(
|
||||
f"Extending prompt failed: {prompt_output.message}")
|
||||
logging.info("Falling back to original prompt.")
|
||||
input_prompt = args.prompt
|
||||
else:
|
||||
input_prompt = prompt_output.prompt
|
||||
input_prompt = [input_prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||
logging.info("Creating WanT2V pipeline.")
|
||||
wan_t2v = wan.WanT2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||
video = wan_t2v.generate(
|
||||
args.prompt,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
|
||||
else:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
if args.image is None:
|
||||
args.image = EXAMPLE_PROMPT[args.task]["image"]
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
logging.info(f"Input image: {args.image}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: load once
|
||||
img = Image.open(args.image).convert("RGB")
|
||||
if args.use_prompt_extend and batch_index == 0: # 20250226 pftq: extend once
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt_output = prompt_expander(
|
||||
args.prompt,
|
||||
tar_lang=args.prompt_extend_target_lang,
|
||||
image=img,
|
||||
seed=args.base_seed)
|
||||
if prompt_output.status == False:
|
||||
logging.info(
|
||||
f"Extending prompt failed: {prompt_output.message}")
|
||||
logging.info("Falling back to original prompt.")
|
||||
input_prompt = args.prompt
|
||||
else:
|
||||
input_prompt = prompt_output.prompt
|
||||
input_prompt = [input_prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
if batch_index == 0: # 20250226 pftq: only load the model once
|
||||
logging.info("Creating WanI2V pipeline.")
|
||||
wan_i2v = wan.WanI2V(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
|
||||
# 20250227 pftq: variety batch
|
||||
if args.variety_batch and batch_index > 0:
|
||||
args.sample_guide_scale = args.sample_guide_scale + 0.5
|
||||
if args.sample_guide_scale > 10:
|
||||
args.sample_guide_scale = 3
|
||||
args.sample_steps = args.sample_steps + 10
|
||||
if args.sample_steps > 150:
|
||||
args.sample_stemps = 50
|
||||
args.save_file = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_wan-server_CFG-"+str(args.sample_guide_scale)+"_steps-"+str(args.sample_steps)+"_"+str(args.ulysses_size)+"-GPUs-"+str(args.frame_num)+"f.mp4"
|
||||
logging.info("Generating video... "+args.save_file)
|
||||
|
||||
else:
|
||||
logging.info("Generating video...")
|
||||
|
||||
video = wan_i2v.generate(
|
||||
args.prompt,
|
||||
img,
|
||||
max_area=MAX_AREA_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
|
||||
"_")[:50]
|
||||
suffix = '.png' if "t2i" in args.task else '.mp4'
|
||||
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
||||
|
||||
# 20250226 pftq: if multiple videos per batch, change the filename index
|
||||
if args.batch_size > 1:
|
||||
if batch_index == 0: args.save_file = args.save_file.replace(".mp4", "__"+str(batch_index+1)+".mp4")
|
||||
else: args.save_file = args.save_file.replace("__"+str(batch_index)+".mp4", "__"+str(batch_index+1)+".mp4")
|
||||
|
||||
if "t2i" in args.task:
|
||||
logging.info(f"Saving generated image to {args.save_file}")
|
||||
cache_image(
|
||||
tensor=video.squeeze(1)[None],
|
||||
save_file=args.save_file,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
else:
|
||||
logging.info(f"Saving generated video to {args.save_file}")
|
||||
cache_video(
|
||||
tensor=video[None],
|
||||
save_file=args.save_file,
|
||||
fps=cfg.sample_fps,
|
||||
nrow=1,
|
||||
normalize=True,
|
||||
value_range=(-1, 1))
|
||||
logging.info("Finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _parse_args()
|
||||
generate(args)
|
||||
|
@ -204,13 +204,19 @@ class WanI2V:
|
||||
generator=seed_g,
|
||||
device=self.device)
|
||||
|
||||
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
||||
#20250226 fix noise assuming frames hardcoded at 81 (21 latent frames)
|
||||
latent_frame_num = (F - 1) // self.vae_stride[0] + 1
|
||||
noise = torch.randn(16, latent_frame_num, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)
|
||||
|
||||
#msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
||||
msk = torch.ones(1, F, lat_h, lat_w, device=self.device) # 20250226 pftq: Fixed frames being hardcoded as 81
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([
|
||||
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
||||
],
|
||||
dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
||||
#msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
||||
msk = msk.view(1, latent_frame_num, 4, lat_h, lat_w) # 20250226 pftq: align to actual frames, not hardcoded 81 frames
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
if n_prompt == "":
|
||||
@ -239,7 +245,8 @@ class WanI2V:
|
||||
torch.nn.functional.interpolate(
|
||||
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
||||
0, 1),
|
||||
torch.zeros(3, 80, h, w)
|
||||
#torch.zeros(3, 80, h, w)
|
||||
torch.zeros(3, F-1, h, w) # 20250226 pftq: fixed 80 being hardcoded frame-1
|
||||
],
|
||||
dim=1).to(self.device)
|
||||
])[0]
|
||||
|
Loading…
Reference in New Issue
Block a user