mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00
Added Ollama support for prompt extension
This commit is contained in:
parent
a326079926
commit
362fbc4ff3
25
README.md
25
README.md
@ -170,6 +170,31 @@ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_
|
||||
```
|
||||
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch'
|
||||
```
|
||||
- Using Ollama for extension (Local API).
|
||||
- Ensure Ollama is installed and running
|
||||
- Pull your desired model: `ollama pull qwen2.5`
|
||||
- The model name is case-sensitive and should match exactly what you pulled in Ollama
|
||||
- For text-to-video tasks:
|
||||
```bash
|
||||
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B \
|
||||
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \
|
||||
--use_prompt_extend \
|
||||
--prompt_extend_method 'ollama' \
|
||||
--prompt_extend_model 'qwen2.5'
|
||||
```
|
||||
- For image-to-video tasks:
|
||||
```bash
|
||||
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P \
|
||||
--image examples/i2v_input.JPG \
|
||||
--prompt "Your prompt here" \
|
||||
--use_prompt_extend \
|
||||
--prompt_extend_method 'ollama' \
|
||||
--prompt_extend_model 'qwen2.5'
|
||||
```
|
||||
- Optional: customize the Ollama API URL if you are not using the default (`http://localhost:11434`):
|
||||
```bash
|
||||
python generate.py [...other args...] --ollama_api_url 'http://your-ollama-server:11434'
|
||||
```
|
||||
|
||||
##### (3) Running local gradio
|
||||
|
||||
|
20
generate.py
20
generate.py
@ -1,4 +1,7 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
# Changelog:
|
||||
# 2025-03-01: Added Ollama support for prompt extension
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import logging
|
||||
@ -14,7 +17,7 @@ from PIL import Image
|
||||
|
||||
import wan
|
||||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
|
||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander, OllamaPromptExpander
|
||||
from wan.utils.utils import cache_video, cache_image, str2bool
|
||||
|
||||
EXAMPLE_PROMPT = {
|
||||
@ -145,8 +148,13 @@ def _parse_args():
|
||||
"--prompt_extend_method",
|
||||
type=str,
|
||||
default="local_qwen",
|
||||
choices=["dashscope", "local_qwen"],
|
||||
choices=["dashscope", "local_qwen", "ollama"],
|
||||
help="The prompt extend method to use.")
|
||||
parser.add_argument(
|
||||
"--ollama_api_url",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The URL of the Ollama API (only used with ollama method).")
|
||||
parser.add_argument(
|
||||
"--prompt_extend_model",
|
||||
type=str,
|
||||
@ -254,6 +262,11 @@ def generate(args):
|
||||
model_name=args.prompt_extend_model,
|
||||
is_vl="i2v" in args.task,
|
||||
device=rank)
|
||||
elif args.prompt_extend_method == "ollama":
|
||||
prompt_expander = OllamaPromptExpander(
|
||||
model_name=args.prompt_extend_model,
|
||||
api_url=args.ollama_api_url,
|
||||
is_vl="i2v" in args.task)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
|
||||
@ -381,8 +394,7 @@ def generate(args):
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
|
||||
"_")[:50]
|
||||
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
|
||||
suffix = '.png' if "t2i" in args.task else '.mp4'
|
||||
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
|
||||
|
||||
|
@ -1,16 +1,21 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
# Changelog:
|
||||
# 2025-03-01: Added OllamaPromptExpander class to support prompt extension using Ollama API
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
from io import BytesIO
|
||||
|
||||
import dashscope
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
@ -456,6 +461,138 @@ class QwenPromptExpander(PromptExpander):
|
||||
ensure_ascii=False))
|
||||
|
||||
|
||||
class OllamaPromptExpander(PromptExpander):
    """Prompt expander backed by a local Ollama server.

    Sends the prompt (and, for visual-language tasks, a base64-encoded
    image) to Ollama's ``/api/generate`` endpoint and returns the model's
    response wrapped in a :class:`PromptOutput`.
    """

    def __init__(self,
                 model_name=None,
                 api_url="http://localhost:11434",
                 is_vl=False,
                 device=0,
                 timeout=300,
                 **kwargs):
        """
        Args:
            model_name: The Ollama model to use (e.g., 'llama2', 'mistral', etc.).
                Defaults to 'llava' for visual-language tasks, 'llama2' otherwise.
            api_url: The URL of the Ollama API.
            is_vl: A flag indicating whether the task involves visual-language
                processing.
            device: The device to use (not used for Ollama as it's API-based;
                kept for interface parity with the other expanders).
            timeout: Seconds to wait for each API request before aborting.
                Without a timeout, a stalled Ollama server would hang the
                caller indefinitely.
            **kwargs: Additional keyword arguments forwarded to the base class.
        """
        if model_name is None:
            model_name = "llama2" if not is_vl else "llava"
        super().__init__(model_name, is_vl, device, **kwargs)
        # Normalize so f"{self.api_url}/api/generate" never gets a double slash.
        self.api_url = api_url.rstrip('/')
        self.timeout = timeout

    def _call_api(self, payload, seed, system_prompt, original_prompt):
        """POST *payload* to /api/generate and wrap the result.

        Shared by :meth:`extend` and :meth:`extend_with_img`. On any failure
        (network error, non-2xx status, bad JSON) returns a failed
        PromptOutput carrying the ORIGINAL prompt so callers can fall back.
        """
        try:
            response = requests.post(
                f"{self.api_url}/api/generate",
                json=payload,
                timeout=self.timeout)
            response.raise_for_status()

            result = response.json()
            return PromptOutput(
                status=True,
                prompt=result.get("response", ""),
                seed=seed,
                system_prompt=system_prompt,
                message=json.dumps(result, ensure_ascii=False))
        except Exception as e:
            return PromptOutput(
                status=False,
                prompt=original_prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=str(e))

    @staticmethod
    def _encode_image(image):
        """Return *image* (file path or PIL Image) as a base64 UTF-8 string."""
        if isinstance(image, str):
            # Image given as a file path: read raw bytes.
            with open(image, "rb") as img_file:
                image_data = img_file.read()
        else:
            # PIL Image: serialize to PNG bytes in memory.
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            image_data = buffer.getvalue()
        return base64.b64encode(image_data).decode("utf-8")

    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
        """
        Extend a text prompt using Ollama API.

        Args:
            prompt: The input prompt to extend
            system_prompt: The system prompt to guide the extension
            seed: Random seed for reproducibility (not used by Ollama)

        Returns:
            PromptOutput: The extended prompt and metadata
        """
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "system": system_prompt,
            "stream": False
        }
        return self._call_api(payload, seed, system_prompt, prompt)

    def extend_with_img(self,
                        prompt,
                        system_prompt,
                        image: Union[Image.Image, str] = None,
                        seed=-1,
                        *args,
                        **kwargs):
        """
        Extend a prompt with an image using Ollama API.

        Args:
            prompt: The input prompt to extend
            system_prompt: The system prompt to guide the extension
            image: The input image (PIL Image or path)
            seed: Random seed for reproducibility (not used by Ollama)

        Returns:
            PromptOutput: The extended prompt and metadata
        """
        try:
            base64_image = self._encode_image(image)
        except Exception as e:
            # Image could not be read/encoded; report failure with the
            # original prompt, matching the API-failure path.
            return PromptOutput(
                status=False,
                prompt=prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=str(e))

        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "system": system_prompt,
            "images": [base64_image],
            "stream": False
        }
        return self._call_api(payload, seed, system_prompt, prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
seed = 100
|
||||
|
Loading…
Reference in New Issue
Block a user