diff --git a/generate.py b/generate.py index 4f114d4..284ae6a 100644 --- a/generate.py +++ b/generate.py @@ -58,17 +58,24 @@ def _validate_args(args): assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" - # TODO(wangang.wa): need to be confirmed # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. if args.sample_steps is None: - args.sample_steps = 40 if "i2v" in args.task else 50 + args.sample_steps = 50 + if "vace" in args.task: + args.sample_steps = 25 + elif "i2v" in args.task: + args.sample_steps = 40 + if args.sample_shift is None: args.sample_shift = 5.0 if "i2v" in args.task and args.size in ["832*480", "480*832"]: args.sample_shift = 3.0 - if "flf2v" in args.task: + elif "flf2v" in args.task: args.sample_shift = 16 + elif "vace" in args.task: + args.sample_shift = 3.0 + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. if args.frame_num is None: