diff --git a/generate.py b/generate.py index c841c19..c33b3b3 100644 --- a/generate.py +++ b/generate.py @@ -312,7 +312,7 @@ def generate(args): elif args.prompt_extend_method == "local_qwen": prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, - is_vl="i2v" in args.task, + is_vl="i2v" in args.task or "flf2v" in args.task, device=rank) else: raise NotImplementedError(