diff --git a/generate.py b/generate.py index 2e6b35c..0477247 100644 --- a/generate.py +++ b/generate.py @@ -299,7 +299,7 @@ def generate(args): elif args.prompt_extend_method == "local_qwen": prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, - is_vl="i2v" in args.task, + is_vl="i2v" in args.task or "flf2v" in args.task, device=rank) else: raise NotImplementedError(