Added support for multiple input images

This commit is contained in:
DeepBeepMeep 2025-03-04 14:22:06 +01:00
parent c02c84961f
commit 24d8beb490
4 changed files with 89 additions and 46 deletions

View File

@ -19,8 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implementented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line switch *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and want to (slightly) accelerate the generation process
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan including the Image to Video model
@ -121,6 +121,11 @@ To run the image to video generator (in Low VRAM mode):
python gradio_server.py --i2v
```
To be able to input multiple images with the image to video generator:
```bash
python gradio_server.py --i2v --multiple-images
```
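When several prompts (one per line) and several starting images are supplied, they are paired up following a simple divisibility rule, as the `generate_video` hunk in the `gradio_server.py` diff below shows. The following sketch restates that rule outside the diff; the helper name `pair_prompts_and_images` is illustrative and not part of the repository:
```python
# Illustrative sketch of the prompt/image pairing rule used with --multiple-images.
# If there are more prompts than images, each image is reused for a block of prompts
# (and vice versa); the larger count must be divisible by the smaller one.
def pair_prompts_and_images(prompts, images):
    if len(prompts) >= len(images):
        if len(prompts) % len(images) != 0:
            raise ValueError("The number of prompts must be divisible by the number of images")
        rep = len(prompts) // len(images)
        images = [images[i // rep] for i in range(len(prompts))]
    else:
        if len(images) % len(prompts) != 0:
            raise ValueError("The number of images must be divisible by the number of prompts")
        rep = len(images) // len(prompts)
        prompts = [prompts[i // rep] for i in range(len(images))]
    return list(zip(prompts, images))

# Example: 4 prompts and 2 images -> the first image is used for prompts 1-2,
# the second image for prompts 3-4.
print(pair_prompts_and_images(["p1", "p2", "p3", "p4"], ["img_a", "img_b"]))
```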
Within the application you can also configure which video generator is launched by default, without specifying a command line switch.
To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model)
@ -155,18 +160,22 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--lora-preset preset : name of the preset file (without the extension) to preload
--verbose level : default (1) : level of information between 0 and 2\
--server-port portno : default (7860) : Gradio port no\
--server-name name : default (0.0.0.0) : Gradio server name\
--server-name name : default (localhost) : Gradio server name\
--open-browser : automatically open the browser when launching the Gradio server\
--lock-config : prevent modifying the video engine configuration from the interface\
--share : create a shareable URL on huggingface so that your server can be accessed remotely\
--multiple-images : allow providing multiple images as starting points for new videos\
--compile : turn on pytorch compilation\
--attention mode : force the attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : no of profile between 1 and 5
--profile no : default (4) : no of profile between 1 and 5\
--preload no : number of megabytes of the diffusion model to preload in VRAM; may offer speed gains, especially if there is otherwise unused VRAM
### Profiles (for power users only)
You can choose between 5 profiles, but only two are really relevant here:
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
- LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
- LowRAM_HighVRAM (3): loads the model entirely in VRAM if possible; slightly faster, but less VRAM is left for the video data afterwards
- LowRAM_LowVRAM (4): loads only the part of the model that is needed; low VRAM and low RAM requirements, but slightly slower
You can adjust how many megabytes of the model are preloaded in VRAM with --preload nnn (where nnn is the number of megabytes to preload).
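As the `load_models` hunk in the `gradio_server.py` diff below shows, the `--preload` value simply replaces the default 100 MB transformer budget handed to the mmgp offloader for profiles 2 and 4. A minimal sketch of that mapping (the helper name `build_offload_kwargs` is illustrative, not part of the repository):
```python
# Minimal sketch of how --preload feeds the mmgp offload budgets (values in MB),
# mirroring the load_models hunk in the gradio_server.py diff below.
def build_offload_kwargs(profile: int, preload_mb: int) -> dict:
    kwargs = {"extraModelsToQuantize": None}
    if profile in (2, 4):
        # Only a small slice of the transformer stays in VRAM by default (100 MB);
        # --preload raises that budget to preload_mb megabytes.
        kwargs["budgets"] = {
            "transformer": 100 if preload_mb == 0 else preload_mb,
            "text_encoder": 100,
            "*": 1000,
        }
    elif profile == 3:
        # Profile 3 lets the models occupy up to 70% of VRAM.
        kwargs["budgets"] = {"*": "70%"}
    return kwargs

# Example: --profile 4 --preload 2000 keeps roughly 2 GB of the transformer in VRAM.
print(build_offload_kwargs(4, 2000))
```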
### Other Models for the GPU Poor
- HuanyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP :\

View File

@ -13,38 +13,17 @@ import random
import json
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback
import math
import asyncio
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--ckpt_dir_480p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--quantize-transformer",
@ -52,6 +31,31 @@ def _parse_args():
help="On the fly 'transformer' quantization"
)
parser.add_argument(
"--share",
action="store_true",
help="Create a shared URL to access webserver remotely"
)
parser.add_argument(
"--lock-config",
action="store_true",
help="Prevent modifying the configuration from the web interface"
)
parser.add_argument(
"--preload",
type=str,
default="0",
help="Megabytes of the diffusion model to preload in VRAM"
)
parser.add_argument(
"--multiple-images",
action="store_true",
help="Allow inputting multiple images with image to video"
)
parser.add_argument(
"--lora-dir-i2v",
@ -163,9 +167,6 @@ def _parse_args():
args = parser.parse_args()
args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
return args
@ -179,7 +180,7 @@ lock_ui_attention = False
lock_ui_transformer = False
lock_ui_compile = False
preload = int(args.preload)
force_profile_no = int(args.profile)
verbose_level = int(args.verbose)
quantizeTransformer = args.quantize_transformer
@ -433,7 +434,10 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 }
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
@ -484,7 +488,8 @@ def apply_changes( state,
vae_config_choice,
default_ui_choice ="t2v",
):
if args.lock_config:
return
if gen_in_progress:
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
return
@ -719,9 +724,32 @@ def generate_video(
global gen_in_progress
gen_in_progress = True
temp_filename = None
if len(prompt) ==0:
return
prompts = prompt.replace("\r", "").split("\n")
if use_image2video:
if image_to_continue is not None:
pass
if isinstance(image_to_continue, list):
image_to_continue = [ tup[0] for tup in image_to_continue ]
else:
image_to_continue = [image_to_continue]
if len(prompts) >= len(image_to_continue):
if len(prompts) % len(image_to_continue) !=0:
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
rep = len(prompts) // len(image_to_continue)
new_image_to_continue = []
for i, _ in enumerate(prompts):
new_image_to_continue.append(image_to_continue[i//rep] )
image_to_continue = new_image_to_continue
else:
if len(image_to_continue) % len(prompts) !=0:
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
rep = len(image_to_continue) // len(prompts)
new_prompts = []
for i, _ in enumerate(image_to_continue):
new_prompts.append( prompts[ i//rep] )
prompts = new_prompts
elif video_to_continue != None and len(video_to_continue) >0 :
input_image_or_video_path = video_to_continue
@ -791,7 +819,6 @@ def generate_video(
from einops import rearrange
save_path = os.path.join(os.getcwd(), "gradio_outputs")
os.makedirs(save_path, exist_ok=True)
prompts = prompt.replace("\r", "").split("\n")
video_no = 0
total_video = repeat_generation * len(prompts)
abort = False
@ -826,7 +853,7 @@ def generate_video(
if use_image2video:
samples = wan_model.generate(
prompt,
image_to_continue,
image_to_continue[video_no-1],
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
@ -1018,7 +1045,7 @@ def create_demo():
header = gr.Markdown(generate_header(transformer_filename_i2v if use_image2video else transformer_filename_t2v, compile, attention_mode) )
with gr.Accordion("Video Engine Configuration - click here to change it", open = False):
with gr.Accordion("Video Engine Configuration - click here to change it", open = False, visible= not args.lock_config):
gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.")
with gr.Column():
@ -1100,7 +1127,7 @@ def create_demo():
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
profile_choice = gr.Dropdown(
@ -1132,6 +1159,11 @@ def create_demo():
with gr.Row():
with gr.Column():
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
if args.multiple_images:
image_to_continue = gr.Gallery(
label="Images as a starting point for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
else:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
if use_image2video:
@ -1317,6 +1349,8 @@ if __name__ == "__main__":
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
if server_port == 0:
server_port = int(os.getenv("SERVER_PORT", "7860"))
@ -1334,6 +1368,6 @@ if __name__ == "__main__":
url = "http://" + server_name
webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
demo.launch(server_name=server_name, server_port=server_port)
demo.launch(server_name=server_name, server_port=server_port, share=args.share)

View File

@ -16,4 +16,4 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.2.1
mmgp==3.2.2

View File

@ -541,6 +541,6 @@ class CLIPModel:
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
with torch.amp.autocast(dtype=self.dtype, device_type="cuda"):
out = self.model.visual(videos.to(torch.bfloat16), use_31_block=True)
return out
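For context on the last hunk: `torch.cuda.amp.autocast` has been deprecated in recent PyTorch releases in favor of the device-agnostic `torch.amp.autocast`, which takes an explicit `device_type`. A minimal sketch of the equivalent usage (assumes a CUDA device is available):
```python
import torch

x = torch.randn(4, 4, device="cuda")

# Old form (still works, but emits a deprecation warning on recent PyTorch):
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#     y = x @ x

# New form, as used in the diff above:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = x @ x
print(y.dtype)  # the matmul ran in bfloat16 under autocast
```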