Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 06:29:14 +00:00)
Added support for multiple input images
Commit 24d8beb490 (parent c02c84961f)
README.md (23 lines changed)
@@ -19,8 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

## 🔥 Latest News!!

* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implementented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line option *--preload x* to preload x MB of the main diffusion model into VRAM if you find there is too much unused VRAM and want to (slightly) accelerate the generation process
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
  - Support for all Wan including the Image to Video model
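The VAE tiling mentioned in the v1.2 entries above amounts to encoding and decoding the latents in spatial tiles instead of in one pass, so the VRAM peak depends on the tile size rather than the full frame. The snippet below is not the code from this commit, only a minimal sketch of the idea: `vae_decode` is a hypothetical callable standing in for the actual decoder, and real implementations overlap and blend tiles to hide seams.

```python
import torch

def tiled_decode(latents: torch.Tensor, vae_decode, tile: int = 32) -> torch.Tensor:
    """Naive tiled VAE decode: split the latent grid into spatial tiles,
    decode each tile separately, and stitch the outputs back together."""
    _, _, h, w = latents.shape
    rows = []
    for y in range(0, h, tile):
        cols = []
        for x in range(0, w, tile):
            # Peak memory is now bounded by the tile size, not the full frame.
            cols.append(vae_decode(latents[:, :, y:y + tile, x:x + tile]))
        rows.append(torch.cat(cols, dim=-1))  # stitch tiles along width
    return torch.cat(rows, dim=-2)            # stitch rows along height
```

Encoding works the same way in reverse, which is why the news entry notes that the VRAM peaks at the beginning and end of a generation disappear.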
@@ -121,6 +121,11 @@ To run the image to video generator (in Low VRAM mode):

python gradio_server.py --i2v
```

To be able to input multiple images with the image to video generator:
```bash
python gradio_server.py --i2v --multiple-images
```

Within the application you can also configure which video generator is launched, without specifying a command line switch.

To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model)
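When --multiple-images is combined with several prompt lines (one prompt per line), the server broadcasts the shorter list against the longer one; the longer list's length must be a multiple of the shorter one's. A minimal standalone sketch of that rule (the function name is hypothetical; the logic mirrors the generate_video change later in this commit):

```python
def pair_prompts_and_images(prompts: list, images: list):
    """Broadcast the shorter list against the longer one, as the server does:
    the longer list's length must be a multiple of the shorter list's length."""
    if len(prompts) >= len(images):
        if len(prompts) % len(images) != 0:
            raise ValueError("The number of text prompts must be divisible by the number of images")
        rep = len(prompts) // len(images)
        images = [images[i // rep] for i in range(len(prompts))]
    else:
        if len(images) % len(prompts) != 0:
            raise ValueError("The number of images must be divisible by the number of text prompts")
        rep = len(images) // len(prompts)
        prompts = [prompts[i // rep] for i in range(len(images))]
    return prompts, images

# Example: 4 prompts and 2 images -> each image is used for 2 consecutive prompts.
prompts, images = pair_prompts_and_images(["p1", "p2", "p3", "p4"], ["imgA", "imgB"])
assert images == ["imgA", "imgA", "imgB", "imgB"]
```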
@@ -155,18 +160,22 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil

--lora-preset preset : name of preset file (without the extension) to preload
--verbose level : default (1) : level of information between 0 and 2\
--server-port portno : default (7860) : Gradio port no\
--server-name name : default (0.0.0.0) : Gradio server name\
--server-name name : default (localhost) : Gradio server name\
--open-browser : open the browser automatically when launching the Gradio server\
--lock-config : prevent modifying the video engine configuration from the interface\
--share : create a shareable URL on huggingface so that your server can be accessed remotely\
--multiple-images : allow multiple images as a starting point for new videos\
--compile : turn on pytorch compilation\
--attention mode : force attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : no of profile between 1 and 5
--profile no : default (4) : no of profile between 1 and 5\
--preload no : number in Megabytes of the diffusion model to partially preload in VRAM, may offer speed gains especially on

### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here:
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
- LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower

- LowRAM_HighVRAM (3): loads the model entirely in VRAM if possible, slightly faster, but less VRAM available for the video data afterwards
- LowRAM_LowVRAM (4): loads only the part of the model that is needed, low VRAM and low RAM requirement but slightly slower

You can adjust how many megabytes of the model to preload with --preload nnn (nnn is the number of megabytes to preload)

### Other Models for the GPU Poor

- HunyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP :\
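Under the hood, --preload only changes the mmgp offloading budget that load_models() passes to offload.profile(): for the low-VRAM profiles 2 and 4 the transformer budget defaults to roughly 100 MB and is raised to the requested number of megabytes when --preload is set. A minimal sketch, with the dictionary values taken from the diff below:

```python
def offload_budgets(profile: int, preload_mb: int = 0) -> dict:
    """Return the mmgp offload budgets for a given profile,
    following the logic of load_models() in this commit."""
    if profile in (2, 4):
        # Keep only ~100 MB of the transformer resident unless --preload asks for more.
        return {
            "transformer": 100 if preload_mb == 0 else preload_mb,
            "text_encoder": 100,
            "*": 1000,
        }
    if profile == 3:
        # LowRAM_HighVRAM: let the models use up to 70% of the available VRAM.
        return {"*": "70%"}
    return {}

# e.g. launching with "--profile 4 --preload 2000" yields a 2000 MB transformer budget.
print(offload_budgets(4, 2000))
```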
gradio_server.py (102 lines changed)
@@ -13,38 +13,17 @@ import random
import json
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback
import math
import asyncio

def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    parser.add_argument(
        "--quantize-transformer",
@@ -52,6 +31,31 @@ def _parse_args():
        help="On the fly 'transformer' quantization"
    )

    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shared URL to access webserver remotely"
    )

    parser.add_argument(
        "--lock-config",
        action="store_true",
        help="Prevent modifying the configuration from the web interface"
    )

    parser.add_argument(
        "--preload",
        type=str,
        default="0",
        help="Megabytes of the diffusion model to preload in VRAM"
    )

    parser.add_argument(
        "--multiple-images",
        action="store_true",
        help="Allow inputting multiple images with image to video"
    )


    parser.add_argument(
        "--lora-dir-i2v",
@@ -163,9 +167,6 @@ def _parse_args():


    args = parser.parse_args()
    args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
    args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
    assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."

    return args
@@ -179,7 +180,7 @@ lock_ui_attention = False
lock_ui_transformer = False
lock_ui_compile = False


preload = int(args.preload)
force_profile_no = int(args.profile)
verbose_level = int(args.verbose)
quantizeTransformer = args.quantize_transformer
@@ -433,7 +434,10 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):

    kwargs = { "extraModelsToQuantize": None}
    if profile == 2 or profile == 4:
        kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 }
        kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
    elif profile == 3:
        kwargs["budgets"] = { "*" : "70%" }


    loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
@@ -484,7 +488,8 @@ def apply_changes( state,
                    vae_config_choice,
                    default_ui_choice ="t2v",
):

    if args.lock_config:
        return
    if gen_in_progress:
        yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
        return
@@ -719,9 +724,32 @@ def generate_video(
    global gen_in_progress
    gen_in_progress = True
    temp_filename = None
    if len(prompt) ==0:
        return
    prompts = prompt.replace("\r", "").split("\n")

    if use_image2video:
        if image_to_continue is not None:
            pass
            if isinstance(image_to_continue, list):
                image_to_continue = [ tup[0] for tup in image_to_continue ]
            else:
                image_to_continue = [image_to_continue]
            if len(prompts) >= len(image_to_continue):
                if len(prompts) % len(image_to_continue) !=0:
                    raise gr.Error("If there are more text prompts than input images, the number of text prompts should be divisible by the number of images")
                rep = len(prompts) // len(image_to_continue)
                new_image_to_continue = []
                for i, _ in enumerate(prompts):
                    new_image_to_continue.append(image_to_continue[i//rep] )
                image_to_continue = new_image_to_continue
            else:
                if len(image_to_continue) % len(prompts) !=0:
                    raise gr.Error("If there are more input images than text prompts, the number of images should be divisible by the number of text prompts")
                rep = len(image_to_continue) // len(prompts)
                new_prompts = []
                for i, _ in enumerate(image_to_continue):
                    new_prompts.append( prompts[ i//rep] )
                prompts = new_prompts

    elif video_to_continue != None and len(video_to_continue) >0 :
        input_image_or_video_path = video_to_continue
@@ -791,7 +819,6 @@ def generate_video(
    from einops import rearrange
    save_path = os.path.join(os.getcwd(), "gradio_outputs")
    os.makedirs(save_path, exist_ok=True)
    prompts = prompt.replace("\r", "").split("\n")
    video_no = 0
    total_video = repeat_generation * len(prompts)
    abort = False
@@ -826,7 +853,7 @@ def generate_video(
            if use_image2video:
                samples = wan_model.generate(
                    prompt,
                    image_to_continue,
                    image_to_continue[video_no-1],
                    frame_num=(video_length // 4)* 4 + 1,
                    max_area=MAX_AREA_CONFIGS[resolution],
                    shift=flow_shift,
@@ -1018,7 +1045,7 @@ def create_demo():

        header = gr.Markdown(generate_header(transformer_filename_i2v if use_image2video else transformer_filename_t2v, compile, attention_mode) )

        with gr.Accordion("Video Engine Configuration - click here to change it", open = False):
        with gr.Accordion("Video Engine Configuration - click here to change it", open = False, visible= not args.lock_config):
            gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.")

            with gr.Column():
@@ -1100,7 +1127,7 @@ def create_demo():
                    ("128 x 128 : If at least 6 GB of VRAM", 3),
                ],
                value= vae_config,
                label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
                label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
            )

            profile_choice = gr.Dropdown(
@@ -1132,6 +1159,11 @@ def create_demo():
            with gr.Row():
                with gr.Column():
                    video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
                    if args.multiple_images:
                        image_to_continue = gr.Gallery(
                                label="Images as a starting point for new videos", type ="pil", #file_types= "image",
                                columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
                    else:
                        image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)

                    if use_image2video:
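The gr.Gallery used when --multiple-images is set returns a list of (image, caption) tuples, while the plain gr.Image returns a single image, which is why generate_video flattens its input before pairing it with the prompts. A tiny sketch of that normalization (the helper name is hypothetical):

```python
def normalize_start_images(value):
    """Turn the raw Gradio component value into a flat list of images:
    a gr.Gallery yields [(image, caption), ...], a gr.Image yields one image."""
    if value is None:
        return []
    if isinstance(value, list):
        return [item[0] for item in value]  # drop the captions
    return [value]
```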
@@ -1317,6 +1349,8 @@ if __name__ == "__main__":
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    server_port = int(args.server_port)

    if os.name == "nt":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    if server_port == 0:
        server_port = int(os.getenv("SERVER_PORT", "7860"))
@@ -1334,6 +1368,6 @@ if __name__ == "__main__":
        url = "http://" + server_name
        webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)

    demo.launch(server_name=server_name, server_port=server_port)
    demo.launch(server_name=server_name, server_port=server_port, share=args.share)
requirements.txt

@@ -16,4 +16,4 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.2.1
mmgp==3.2.2
wan/modules/clip.py

@@ -541,6 +541,6 @@ class CLIPModel:
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.cuda.amp.autocast(dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        with torch.amp.autocast(dtype=self.dtype, device_type="cuda"):
            out = self.model.visual(videos.to(torch.bfloat16), use_31_block=True)
        return out
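The hunk above swaps the deprecated torch.cuda.amp.autocast context manager for the device-agnostic torch.amp.autocast form and feeds the visual tower bfloat16 inputs. A minimal, repository-independent sketch of the new-style call (assumes a CUDA device is available):

```python
import torch

model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(2, 16, device="cuda")

# Old spelling (deprecated): with torch.cuda.amp.autocast(dtype=torch.bfloat16): ...
# New, device-agnostic spelling, as used in this commit:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)   # matmuls run in bfloat16, numerically sensitive ops stay in fp32

print(out.dtype)  # torch.bfloat16
```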