diff --git a/generate.py b/generate.py
index 1b1a9d7..c34c1c9 100644
--- a/generate.py
+++ b/generate.py
@@ -27,6 +27,9 @@ EXAMPLE_PROMPT = {
     "t2i-14B": {
         "prompt": "一个朴素端庄的美人",
     },
+    "t2i-1.3B": {
+        "prompt": "一个朴素端庄的美人",
+    },
     "i2v-14B": {
         "prompt":
             "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
diff --git a/gradio/t2i_1.3B_singleGPU.py b/gradio/t2i_1.3B_singleGPU.py
new file mode 100644
index 0000000..22bf874
--- /dev/null
+++ b/gradio/t2i_1.3B_singleGPU.py
@@ -0,0 +1,205 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import os.path as osp
+import os
+import sys
+import warnings
+
+import gradio as gr
+
+warnings.filterwarnings('ignore')
+
+# Model
+sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+import wan
+from wan.configs import WAN_CONFIGS
+from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
+from wan.utils.utils import cache_image
+
+# Global Var
+prompt_expander = None
+wan_t2i = None
+
+
+# Button Func
+def prompt_enc(prompt, tar_lang):
+    global prompt_expander
+    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
+    if not prompt_output.status:
+        return prompt
+    else:
+        return prompt_output.prompt
+
+
+def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
+                   shift_scale, seed, n_prompt):
+    global wan_t2i
+    # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
+
+    W = int(resolution.split("*")[0])
+    H = int(resolution.split("*")[1])
+    video = wan_t2i.generate(
+        txt2img_prompt,
+        size=(W, H),
+        frame_num=1,
+        shift=shift_scale,
+        sampling_steps=sd_steps,
+        guide_scale=guide_scale,
+        n_prompt=n_prompt,
+        seed=seed,
+        offload_model=True)
+
+    cache_image(
+        tensor=video.squeeze(1)[None],
+        save_file="example.png",
+        nrow=1,
+        normalize=True,
+        value_range=(-1, 1))
+
+    return "example.png"
+
+
+# Interface
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("""
+                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+                        Wan2.1 (T2I-1.3B)
+                    </div>
+                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
+                        Wan: Open and Advanced Large-Scale Video Generative Models.
+                    </div>
+                    """)
+
+        with gr.Row():
+            with gr.Column():
+                txt2img_prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Describe the image you want to generate",
+                )
+                tar_lang = gr.Radio(
+                    choices=["ZH", "EN"],
+                    label="Target language of prompt enhance",
+                    value="ZH")
+                run_p_button = gr.Button(value="Prompt Enhance")
+
+                with gr.Accordion("Advanced Options", open=True):
+                    resolution = gr.Dropdown(
+                        label='Resolution(Width*Height)',
+                        choices=[
+                            '720*1280', '1280*720', '960*960', '1088*832',
+                            '832*1088', '480*832', '832*480', '624*624',
+                            '704*544', '544*704'
+                        ],
+                        value='720*1280')
+
+                    with gr.Row():
+                        sd_steps = gr.Slider(
+                            label="Diffusion steps",
+                            minimum=1,
+                            maximum=1000,
+                            value=50,
+                            step=1)
+                        guide_scale = gr.Slider(
+                            label="Guide scale",
+                            minimum=0,
+                            maximum=20,
+                            value=5.0,
+                            step=1)
+                    with gr.Row():
+                        shift_scale = gr.Slider(
+                            label="Shift scale",
+                            minimum=0,
+                            maximum=10,
+                            value=5.0,
+                            step=1)
+                        seed = gr.Slider(
+                            label="Seed",
+                            minimum=-1,
+                            maximum=2147483647,
+                            step=1,
+                            value=-1)
+                    n_prompt = gr.Textbox(
+                        label="Negative Prompt",
+                        placeholder="Describe the negative prompt you want to add"
+                    )
+
+                run_t2i_button = gr.Button("Generate Image")
+
+            with gr.Column():
+                result_gallery = gr.Image(
+                    label='Generated Image', interactive=False, height=600)
+
+        run_p_button.click(
+            fn=prompt_enc,
+            inputs=[txt2img_prompt, tar_lang],
+            outputs=[txt2img_prompt])
+
+        run_t2i_button.click(
+            fn=t2i_generation,
+            inputs=[
+                txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
+                seed, n_prompt
+            ],
+            outputs=[result_gallery],
+        )
+
+    return demo
+
+
+# Main
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        description="Generate an image from a text prompt or image using Gradio")
+    parser.add_argument(
+        "--ckpt_dir",
+        type=str,
+        default="cache",
+        help="The path to the checkpoint directory.")
+    parser.add_argument(
+        "--prompt_extend_method",
+        type=str,
+        default="local_qwen",
+        choices=["dashscope", "local_qwen"],
+        help="The prompt extend method to use.")
+    parser.add_argument(
+        "--prompt_extend_model",
+        type=str,
+        default=None,
+        help="The prompt extend model to use.")
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == '__main__':
+    args = _parse_args()
+
+    print("Step1: Init prompt_expander...", end='', flush=True)
+    if args.prompt_extend_method == "dashscope":
+        prompt_expander = DashScopePromptExpander(
+            model_name=args.prompt_extend_model, is_vl=False)
+    elif args.prompt_extend_method == "local_qwen":
+        prompt_expander = QwenPromptExpander(
+            model_name=args.prompt_extend_model, is_vl=False, device=0)
+    else:
+        raise NotImplementedError(
+            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
+    print("done", flush=True)
+
+    print("Step2: Init 1.3B t2i model...", end='', flush=True)
+    cfg = WAN_CONFIGS['t2i-1.3B']
+    wan_t2i = wan.WanT2V(
+        config=cfg,
+        checkpoint_dir=args.ckpt_dir,
+        device_id=0,
+        rank=0,
+        t5_fsdp=False,
+        dit_fsdp=False,
+        use_usp=False,
+    )
+    print("done", flush=True)
+
+    demo = gradio_interface()
+    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py
index c72d2d0..1925bc5 100644
--- a/wan/configs/__init__.py
+++ b/wan/configs/__init__.py
@@ -11,12 +11,15 @@ from .wan_t2v_14B import t2v_14B
 # the config of t2i_14B is the same as t2v_14B
 t2i_14B = copy.deepcopy(t2v_14B)
 t2i_14B.__name__ = 'Config: Wan T2I 14B'
+t2i_1_3B = copy.deepcopy(t2v_1_3B)
+t2i_1_3B.__name__ = 'Config: Wan T2I 1.3B'
 
 WAN_CONFIGS = {
     't2v-14B': t2v_14B,
     't2v-1.3B': t2v_1_3B,
     'i2v-14B': i2v_14B,
     't2i-14B': t2i_14B,
+    't2i-1.3B': t2i_1_3B,
 }
 
 SIZE_CONFIGS = {
@@ -39,4 +42,5 @@ SUPPORTED_SIZES = {
     't2v-1.3B': ('480*832', '832*480'),
     'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2i-14B': tuple(SIZE_CONFIGS.keys()),
+    't2i-1.3B': tuple(SIZE_CONFIGS.keys()),
 }