mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00

Added Ollama support for prompt extension

parent a326079926
commit 362fbc4ff3

README.md (25 lines changed)
@@ -170,6 +170,31 @@ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_
 ```
 python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch'
 ```
+- Using Ollama for extension (local API):
+  - Ensure Ollama is installed and running.
+  - Pull your desired model: `ollama pull qwen2.5`
+  - The model name is case-sensitive and should match exactly what you pulled in Ollama.
+- For text-to-video tasks:
+```bash
+python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B \
+    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" \
+    --use_prompt_extend \
+    --prompt_extend_method 'ollama' \
+    --prompt_extend_model 'qwen2.5'
+```
+- For image-to-video tasks:
+```bash
+python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P \
+    --image examples/i2v_input.JPG \
+    --prompt "Your prompt here" \
+    --use_prompt_extend \
+    --prompt_extend_method 'ollama' \
+    --prompt_extend_model 'qwen2.5'
+```
+- Optional: customize the Ollama API URL if not using the default (http://localhost:11434):
+```bash
+python generate.py [...other args...] --ollama_api_url 'http://your-ollama-server:11434'
+```
 
 ##### (3) Running local gradio
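Before invoking generate.py with `--prompt_extend_method 'ollama'`, it can help to confirm the daemon actually answers. A minimal connectivity check, sketched here as an editorial aside rather than part of the commit, assuming the default port and a previously pulled `qwen2.5` model:

```python
# Minimal connectivity check for a local Ollama daemon (assumes the
# default http://localhost:11434 and that `ollama pull qwen2.5` was run).
import requests

resp = requests.post(
    "http://localhost:11434/api/generate",
    json={"model": "qwen2.5", "prompt": "Say hi.", "stream": False},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])
```

If this prints a short completion, the `ollama` extension path below has everything it needs.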
generate.py (20 lines changed)
@@ -1,4 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+# Changelog:
+# 2025-03-01: Added Ollama support for prompt extension
+
 import argparse
 from datetime import datetime
 import logging
@@ -14,7 +17,7 @@ from PIL import Image
 
 import wan
 from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
-from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
+from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander, OllamaPromptExpander
 from wan.utils.utils import cache_video, cache_image, str2bool
 
 EXAMPLE_PROMPT = {
@@ -145,8 +148,13 @@ def _parse_args():
         "--prompt_extend_method",
         type=str,
         default="local_qwen",
-        choices=["dashscope", "local_qwen"],
+        choices=["dashscope", "local_qwen", "ollama"],
         help="The prompt extend method to use.")
+    parser.add_argument(
+        "--ollama_api_url",
+        type=str,
+        default="http://localhost:11434",
+        help="The URL of the Ollama API (only used with ollama method).")
     parser.add_argument(
         "--prompt_extend_model",
         type=str,
@@ -254,6 +262,11 @@ def generate(args):
             model_name=args.prompt_extend_model,
             is_vl="i2v" in args.task,
             device=rank)
+    elif args.prompt_extend_method == "ollama":
+        prompt_expander = OllamaPromptExpander(
+            model_name=args.prompt_extend_model,
+            api_url=args.ollama_api_url,
+            is_vl="i2v" in args.task)
     else:
         raise NotImplementedError(
             f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
@@ -381,8 +394,7 @@ def generate(args):
     if rank == 0:
         if args.save_file is None:
             formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
-            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
-                                                                     "_")[:50]
+            formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
             suffix = '.png' if "t2i" in args.task else '.mp4'
             args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
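For readers skimming the diff, the moving parts above compose as follows. This is a condensed, hypothetical stand-alone sketch (argument names taken from the diff; everything else simplified), not code from the commit:

```python
# Condensed sketch of the new CLI surface and dispatch added above.
# Hypothetical stand-alone rendition; the real logic lives in generate.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--prompt_extend_method", type=str, default="local_qwen",
                    choices=["dashscope", "local_qwen", "ollama"])
parser.add_argument("--ollama_api_url", type=str,
                    default="http://localhost:11434")
parser.add_argument("--prompt_extend_model", type=str, default=None)
parser.add_argument("--task", type=str, default="t2v-14B")

args = parser.parse_args(["--prompt_extend_method", "ollama",
                          "--prompt_extend_model", "qwen2.5"])

if args.prompt_extend_method == "ollama":
    # Mirrors the new elif branch: only this path consumes --ollama_api_url,
    # and i2v tasks flip the expander into visual-language mode.
    print("model:", args.prompt_extend_model)
    print("api_url:", args.ollama_api_url)
    print("is_vl:", "i2v" in args.task)
```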
wan/utils/prompt_extend.py

@@ -1,16 +1,21 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+# Changelog:
+# 2025-03-01: Added OllamaPromptExpander class to support prompt extension using the Ollama API
 import json
 import math
 import os
 import random
 import sys
 import tempfile
+import base64
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Optional, Union
+from io import BytesIO
 
 import dashscope
 import torch
+import requests
 from PIL import Image
 
 try:
@@ -456,6 +461,138 @@ class QwenPromptExpander(PromptExpander):
                     ensure_ascii=False))
 
 
+class OllamaPromptExpander(PromptExpander):
+
+    def __init__(self,
+                 model_name=None,
+                 api_url="http://localhost:11434",
+                 is_vl=False,
+                 device=0,
+                 **kwargs):
+        """
+        Args:
+            model_name: The Ollama model to use (e.g., 'llama2', 'mistral').
+            api_url: The URL of the Ollama API.
+            is_vl: Whether the task involves visual-language processing.
+            device: The device to use (unused; Ollama is API-based).
+            **kwargs: Additional keyword arguments.
+        """
+        if model_name is None:
+            model_name = "llama2" if not is_vl else "llava"
+        super().__init__(model_name, is_vl, device, **kwargs)
+        self.api_url = api_url.rstrip('/')
+
+    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
+        """
+        Extend a text prompt using the Ollama API.
+
+        Args:
+            prompt: The input prompt to extend.
+            system_prompt: The system prompt that guides the extension.
+            seed: Random seed for reproducibility (not used by Ollama).
+
+        Returns:
+            PromptOutput: The extended prompt and metadata.
+        """
+        try:
+            # Format the message for the Ollama API.
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "system": system_prompt,
+                "stream": False
+            }
+
+            # Call the Ollama API.
+            response = requests.post(f"{self.api_url}/api/generate", json=payload)
+            response.raise_for_status()
+
+            # Parse the response.
+            result = response.json()
+            expanded_prompt = result.get("response", "")
+
+            return PromptOutput(
+                status=True,
+                prompt=expanded_prompt,
+                seed=seed,
+                system_prompt=system_prompt,
+                message=json.dumps(result, ensure_ascii=False))
+        except Exception as e:
+            return PromptOutput(
+                status=False,
+                prompt=prompt,
+                seed=seed,
+                system_prompt=system_prompt,
+                message=str(e))
+
+    def extend_with_img(self,
+                        prompt,
+                        system_prompt,
+                        image: Union[Image.Image, str] = None,
+                        seed=-1,
+                        *args,
+                        **kwargs):
+        """
+        Extend a prompt with an image using the Ollama API.
+
+        Args:
+            prompt: The input prompt to extend.
+            system_prompt: The system prompt that guides the extension.
+            image: The input image (PIL Image or file path).
+            seed: Random seed for reproducibility (not used by Ollama).
+
+        Returns:
+            PromptOutput: The extended prompt and metadata.
+        """
+        try:
+            # Convert the image to base64.
+            if isinstance(image, str):
+                # If image is a file path, read it.
+                with open(image, "rb") as img_file:
+                    image_data = img_file.read()
+            else:
+                # If image is a PIL Image, convert it to bytes.
+                buffer = BytesIO()
+                image.save(buffer, format="PNG")
+                image_data = buffer.getvalue()
+
+            base64_image = base64.b64encode(image_data).decode("utf-8")
+
+            # Format the message for the Ollama API.
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "system": system_prompt,
+                "images": [base64_image],
+                "stream": False
+            }
+
+            # Call the Ollama API.
+            response = requests.post(f"{self.api_url}/api/generate", json=payload)
+            response.raise_for_status()
+
+            # Parse the response.
+            result = response.json()
+            expanded_prompt = result.get("response", "")
+
+            return PromptOutput(
+                status=True,
+                prompt=expanded_prompt,
+                seed=seed,
+                system_prompt=system_prompt,
+                message=json.dumps(result, ensure_ascii=False))
+        except Exception as e:
+            return PromptOutput(
+                status=False,
+                prompt=prompt,
+                seed=seed,
+                system_prompt=system_prompt,
+                message=str(e))
+
+
 if __name__ == "__main__":
 
     seed = 100