mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 23:04:54 +00:00
Merge a0bf867138
into 679ccc6c68
This commit is contained in:
commit
55f6b2b7db
@ -27,6 +27,9 @@ EXAMPLE_PROMPT = {
|
||||
"t2i-14B": {
|
||||
"prompt": "一个朴素端庄的美人",
|
||||
},
|
||||
"t2i-1.3B": {
|
||||
"prompt": "一个朴素端庄的美人",
|
||||
},
|
||||
"i2v-14B": {
|
||||
"prompt":
|
||||
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
|
||||
|
205
gradio/t2i_1.3B_singleGPU.py
Normal file
205
gradio/t2i_1.3B_singleGPU.py
Normal file
@ -0,0 +1,205 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import gradio as gr
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Model
|
||||
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
|
||||
import wan
|
||||
from wan.configs import WAN_CONFIGS
|
||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||
from wan.utils.utils import cache_image
|
||||
|
||||
# Global Var
# Both are populated in the __main__ block below and read by the
# button callbacks (prompt_enc / t2i_generation).
prompt_expander = None
wan_t2i = None
# Button Func
def prompt_enc(prompt, tar_lang):
    """Enhance *prompt* via the globally initialized prompt expander.

    Args:
        prompt: Raw prompt text entered by the user.
        tar_lang: Target language chosen in the UI ("ZH" or "EN").

    Returns:
        The expanded prompt on success; the original prompt unchanged
        if expansion reports failure.
    """
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    # Fall back to the user's own prompt when the expander fails
    # (PEP 8: test truthiness rather than comparing to False).
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt
def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    """Generate one image with the global T2I model and save it to disk.

    Args:
        txt2img_prompt: Text prompt describing the desired image.
        resolution: UI dropdown value formatted as "Width*Height".
        sd_steps: Number of diffusion sampling steps.
        guide_scale: Classifier-free guidance scale.
        shift_scale: Noise-schedule shift parameter.
        seed: Random seed; -1 lets the backend choose one.
        n_prompt: Negative prompt.

    Returns:
        Path of the saved image file ("example.png").
    """
    global wan_t2i
    # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    # Parse the "W*H" dropdown value once instead of splitting twice.
    W, H = (int(v) for v in resolution.split("*"))
    # frame_num=1 turns the video pipeline into single-image generation.
    video = wan_t2i.generate(
        txt2img_prompt,
        size=(W, H),
        frame_num=1,
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_image(
        tensor=video.squeeze(1)[None],
        save_file="example.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.png"
# Interface
def gradio_interface():
    """Build the Gradio Blocks UI for the Wan2.1 T2I-1.3B single-GPU demo.

    Wires the "Prompt Enhance" button to ``prompt_enc`` and the
    "Generate Image" button to ``t2i_generation``.

    Returns:
        The assembled ``gr.Blocks`` demo, ready for ``.launch()``.
    """
    with gr.Blocks() as demo:
        # Page header.
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2I-1.3B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                # Prompt entry plus optional LLM-based prompt enhancement.
                txt2img_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    # Value is parsed as "W*H" by t2i_generation.
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        # -1 means "random seed"; max is 2**31 - 1.
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2i_button = gr.Button("Generate Image")

            with gr.Column():
                result_gallery = gr.Image(
                    label='Generated Image', interactive=False, height=600)

        # Enhanced prompt overwrites the prompt textbox in place.
        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2img_prompt, tar_lang],
            outputs=[txt2img_prompt])

        run_t2i_button.click(
            fn=t2i_generation,
            inputs=[
                txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo
# Main
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a image from a text prompt or image using Gradio")
|
||||
parser.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="cache",
|
||||
help="The path to the checkpoint directory.")
|
||||
parser.add_argument(
|
||||
"--prompt_extend_method",
|
||||
type=str,
|
||||
default="local_qwen",
|
||||
choices=["dashscope", "local_qwen"],
|
||||
help="The prompt extend method to use.")
|
||||
parser.add_argument(
|
||||
"--prompt_extend_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt extend model to use.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = _parse_args()

    # Step 1: build the prompt expander selected on the command line.
    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        # Remote DashScope API; is_vl=False because T2I has no image input.
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        # Local Qwen model on GPU 0.
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        # Unreachable through argparse choices; kept as a safety net.
        # Fixed typo: "Unsupport" -> "Unsupported".
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    # Step 2: load the 1.3B model. The T2V pipeline is reused for T2I
    # (t2i_generation calls it with frame_num=1).
    print("Step2: Init 1.3B t2i model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2i-1.3B']
    wan_t2i = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,  # single GPU: no FSDP sharding
        dit_fsdp=False,
        use_usp=False,  # no sequence parallelism
    )
    print("done", flush=True)

    demo = gradio_interface()
    # Listen on all interfaces, port 7860, without a public share link.
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
@ -11,12 +11,15 @@ from .wan_t2v_14B import t2v_14B
|
||||
# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
# likewise, t2i_1.3B reuses the t2v_1.3B config under a new name
t2i_1_3B = copy.deepcopy(t2v_1_3B)
t2i_1_3B.__name__ = 'Config: Wan T2I 1.3B'
||||
|
||||
# Registry mapping task names (as passed on the CLI / used by the demos)
# to their config objects.
WAN_CONFIGS = {
    't2v-14B': t2v_14B,
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
    't2i-1.3B': t2i_1_3B,
}
||||
|
||||
SIZE_CONFIGS = {
|
||||
@ -39,4 +42,5 @@ SUPPORTED_SIZES = {
|
||||
't2v-1.3B': ('480*832', '832*480'),
|
||||
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
||||
't2i-1.3B': tuple(SIZE_CONFIGS.keys()),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user