Added Ollama support for prompt extension

Egar Almeida 2025-03-02 12:28:24 -03:00
parent a326079926
commit 362fbc4ff3
4 changed files with 1548 additions and 1374 deletions

View File: README.md

@@ -170,6 +170,31 @@ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_
```
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch'
```
- Using Ollama for extension (local API).
- Ensure Ollama is installed and running (a quick pre-flight check is sketched at the end of this section).
- Pull your desired model: `ollama pull qwen2.5`.
- The model name is case-sensitive and must match exactly the name you pulled in Ollama.
- For text-to-video tasks:
```bash
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \
--use_prompt_extend \
--prompt_extend_method 'ollama' \
--prompt_extend_model 'qwen2.5'
```
- For image-to-video tasks:
```bash
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P \
--image examples/i2v_input.JPG \
--prompt "Your prompt here" \
--use_prompt_extend \
--prompt_extend_method 'ollama' \
--prompt_extend_model 'qwen2.5'
```
- Optional: customize the Ollama API URL if you are not using the default (`http://localhost:11434`):
```bash
python generate.py [...other args...] --ollama_api_url 'http://your-ollama-server:11434'
```
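- Optional: verify that the server is reachable and the model is present before running. A minimal pre-flight sketch, assuming the default API URL and the `qwen2.5` model pulled above (`/api/tags` is Ollama's model-listing endpoint):
```python
# Pre-flight check: confirm the Ollama server is up and the model exists.
import requests

resp = requests.get("http://localhost:11434/api/tags", timeout=5)
resp.raise_for_status()
names = [m["name"] for m in resp.json().get("models", [])]
print(names)  # expect an entry like 'qwen2.5:latest'
```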
##### (3) Running local gradio

View File: generate.py

@@ -1,4 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Changelog:
# 2025-03-01: Added Ollama support for prompt extension
import argparse
from datetime import datetime
import logging
@@ -14,7 +17,7 @@ from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander, OllamaPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
EXAMPLE_PROMPT = {
@@ -145,8 +148,13 @@ def _parse_args():
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        choices=["dashscope", "local_qwen", "ollama"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--ollama_api_url",
        type=str,
        default="http://localhost:11434",
        help="The URL of the Ollama API (only used with the ollama method).")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
@@ -254,6 +262,11 @@ def generate(args):
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank)
        elif args.prompt_extend_method == "ollama":
            prompt_expander = OllamaPromptExpander(
                model_name=args.prompt_extend_model,
                api_url=args.ollama_api_url,
                is_vl="i2v" in args.task)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
@@ -381,8 +394,7 @@ def generate(args):
    if rank == 0:
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
                "_")[:50]
            formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
            suffix = '.png' if "t2i" in args.task else '.mp4'
            args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
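Taken together, the new flag and dispatch boil down to constructing an `OllamaPromptExpander` and calling its `extend` method. A minimal standalone sketch, assuming a running Ollama server with `qwen2.5` pulled (`extend` and `PromptOutput` are defined in `wan/utils/prompt_extend.py` below; the prompt text is illustrative):

```python
# Minimal sketch of the new code path, outside of generate.py.
from wan.utils.prompt_extend import OllamaPromptExpander

expander = OllamaPromptExpander(
    model_name="qwen2.5",  # assumed to be pulled already
    api_url="http://localhost:11434")
out = expander.extend(
    prompt="Two anthropomorphic cats in comfy boxing gear fight on a stage",
    system_prompt="Rewrite the prompt with rich visual detail for video generation.")
print(out.prompt if out.status else f"extension failed: {out.message}")
```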

View File: wan/utils/prompt_extend.py

@@ -1,16 +1,21 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Changelog:
# 2025-03-01: Added OllamaPromptExpander class to support prompt extension using Ollama API
import json
import math
import os
import random
import sys
import tempfile
import base64
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
from io import BytesIO
import dashscope
import torch
import requests
from PIL import Image
try:
@@ -456,6 +461,138 @@ class QwenPromptExpander(PromptExpander):
                ensure_ascii=False))


class OllamaPromptExpander(PromptExpander):

    def __init__(self,
                 model_name=None,
                 api_url="http://localhost:11434",
                 is_vl=False,
                 device=0,
                 **kwargs):
        """
        Args:
            model_name: The Ollama model to use (e.g., 'llama2', 'mistral', etc.)
            api_url: The URL of the Ollama API
            is_vl: A flag indicating whether the task involves visual-language processing
            device: The device to use (not used for Ollama as it's API-based)
            **kwargs: Additional keyword arguments
        """
        if model_name is None:
            model_name = "llama2" if not is_vl else "llava"
        super().__init__(model_name, is_vl, device, **kwargs)
        self.api_url = api_url.rstrip('/')

    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
        """
        Extend a text prompt using the Ollama API.

        Args:
            prompt: The input prompt to extend
            system_prompt: The system prompt to guide the extension
            seed: Random seed for reproducibility (not used by Ollama)

        Returns:
            PromptOutput: The extended prompt and metadata
        """
        try:
            # Format the message for the Ollama API
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                "system": system_prompt,
                "stream": False
            }

            # Call the Ollama API
            response = requests.post(
                f"{self.api_url}/api/generate", json=payload)
            response.raise_for_status()

            # Parse the response
            result = response.json()
            expanded_prompt = result.get("response", "")

            return PromptOutput(
                status=True,
                prompt=expanded_prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=json.dumps(result, ensure_ascii=False))
        except Exception as e:
            return PromptOutput(
                status=False,
                prompt=prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=str(e))

    def extend_with_img(self,
                        prompt,
                        system_prompt,
                        image: Union[Image.Image, str] = None,
                        seed=-1,
                        *args,
                        **kwargs):
        """
        Extend a prompt with an image using the Ollama API.

        Args:
            prompt: The input prompt to extend
            system_prompt: The system prompt to guide the extension
            image: The input image (PIL Image or file path)
            seed: Random seed for reproducibility (not used by Ollama)

        Returns:
            PromptOutput: The extended prompt and metadata
        """
        try:
            # Read the image into raw bytes
            if isinstance(image, str):
                # If image is a file path, read it
                with open(image, "rb") as img_file:
                    image_data = img_file.read()
            else:
                # If image is a PIL Image, convert it to PNG bytes
                buffer = BytesIO()
                image.save(buffer, format="PNG")
                image_data = buffer.getvalue()

            # Encode the image bytes as base64
            base64_image = base64.b64encode(image_data).decode("utf-8")

            # Format the message for the Ollama API
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                "system": system_prompt,
                "images": [base64_image],
                "stream": False
            }

            # Call the Ollama API
            response = requests.post(
                f"{self.api_url}/api/generate", json=payload)
            response.raise_for_status()

            # Parse the response
            result = response.json()
            expanded_prompt = result.get("response", "")

            return PromptOutput(
                status=True,
                prompt=expanded_prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=json.dumps(result, ensure_ascii=False))
        except Exception as e:
            return PromptOutput(
                status=False,
                prompt=prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=str(e))

if __name__ == "__main__":
seed = 100
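
For image-to-video, `extend_with_img` follows the same request shape but base64-encodes the image and sends it in the `images` field of `/api/generate`. A minimal sketch of that path, assuming the default `llava` fallback model is pulled and the README's sample image exists:

```python
# Minimal sketch of the vision path; mirrors extend_with_img's internals.
from wan.utils.prompt_extend import OllamaPromptExpander

expander = OllamaPromptExpander(is_vl=True)  # falls back to the 'llava' model
out = expander.extend_with_img(
    prompt="Describe this scene as a cinematic video prompt",
    system_prompt="Expand the prompt with rich visual detail.",
    image="examples/i2v_input.JPG")
print(out.prompt if out.status else f"extension failed: {out.message}")
```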