Added support for multiple input images

DeepBeepMeep 2025-03-04 14:22:06 +01:00
parent c02c84961f
commit 24d8beb490
4 changed files with 89 additions and 46 deletions

View File

@ -19,8 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line option *--preload x* to preload x MB of the main diffusion model in VRAM, if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added TeaCache support for faster generations: an optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of TeaCache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan models, including the Image to Video model
@ -121,6 +121,11 @@ To run the image to video generator (in Low VRAM mode):
python gradio_server.py --i2v
```
To be able to input multiple images with the image to video generator:
```bash
python gradio_server.py --i2v --multiple-images
```
Within the application you can configure which video generator will be launched without specifying a command line switch.
To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model)
@ -155,18 +160,22 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--lora-preset preset : name of preset file (without the extension) to preload
--verbose level : default (1) : level of information between 0 and 2\
--server-port portno : default (7860) : Gradio port no\
--server-name name : default (0.0.0.0) : Gradio server name\
--server-name name : default (localhost) : Gradio server name\
--open-browser : automatically open the browser when launching the Gradio server\
--lock-config : prevent modifying the video engine configuration from the interface\
--share : create a shareable URL on huggingface so that your server can be accessed remotely\
--multiple-images : allow providing several images as starting points for new videos\
--compile : turn on pytorch compilation\
--attention mode: force attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : no of profile between 1 and 5\
--preload no : number of megabytes of the main diffusion model to preload in VRAM; may offer slight speed gains if there is unused VRAM
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here :
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
- LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
- LowRAM_HighVRAM (3): loads the model entirely in VRAM if possible; slightly faster, but leaves less VRAM available for the video data afterwards
- LowRAM_LowVRAM (4): loads only the parts of the model that are needed; low VRAM and low RAM requirements, but slightly slower
You can adjust how many megabytes of the model are preloaded with --preload nnn (where nnn is the number of megabytes to preload).
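Under the hood (see the load_models change further down in this diff), for profiles 2 and 4 the preload value replaces the default ~100 MB transformer budget handed to the mmgp offloader. A minimal sketch of that mapping, with a helper name chosen here purely for illustration:
```python
def vram_budgets_mb(preload_mb: int) -> dict:
    # Mirrors the budgets dict built in load_models for profiles 2 and 4:
    # with --preload 0 only ~100 MB of the transformer stays resident,
    # otherwise the requested number of megabytes is kept in VRAM.
    return {
        "transformer": 100 if preload_mb == 0 else preload_mb,
        "text_encoder": 100,
        "*": 1000,
    }

print(vram_budgets_mb(0))     # {'transformer': 100, 'text_encoder': 100, '*': 1000}
print(vram_budgets_mb(3000))  # keep ~3 GB of the diffusion model in VRAM for a small speedup
```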
### Other Models for the GPU Poor
- HunyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP :\

View File

@ -13,38 +13,17 @@ import random
import json
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback
import math
import asyncio
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--quantize-transformer",
@ -52,6 +31,31 @@ def _parse_args():
help="On the fly 'transformer' quantization"
)
parser.add_argument(
"--share",
action="store_true",
help="Create a shared URL to access webserver remotely"
)
parser.add_argument(
"--lock-config",
action="store_true",
help="Prevent modifying the configuration from the web interface"
)
parser.add_argument(
"--preload",
type=str,
default="0",
help="Megabytes of the diffusion model to preload in VRAM"
)
parser.add_argument(
"--multiple-images",
action="store_true",
help="Allow inputting multiple images with image to video"
)
parser.add_argument(
"--lora-dir-i2v",
@ -163,9 +167,6 @@ def _parse_args():
    args = parser.parse_args()
    args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
    args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
    assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
    return args
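As an aside, argparse converts dashes to underscores for attribute access, which is why the new --multiple-images and --lock-config switches are read back as args.multiple_images and args.lock_config elsewhere in gradio_server.py; a quick self-contained check:
```python
import argparse

# argparse exposes "--multiple-images" as args.multiple_images, "--lock-config" as args.lock_config
parser = argparse.ArgumentParser()
parser.add_argument("--multiple-images", action="store_true")
parser.add_argument("--lock-config", action="store_true")

args = parser.parse_args(["--multiple-images"])
print(args.multiple_images, args.lock_config)  # True False
```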
@ -179,7 +180,7 @@ lock_ui_attention = False
lock_ui_transformer = False
lock_ui_compile = False
preload = int(args.preload)
force_profile_no = int(args.profile)
verbose_level = int(args.verbose)
quantizeTransformer = args.quantize_transformer
@ -433,7 +434,10 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 }
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
@ -484,7 +488,8 @@ def apply_changes( state,
                    vae_config_choice,
                    default_ui_choice ="t2v",
):
    if args.lock_config:
        return
    if gen_in_progress:
        yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
        return
@ -719,9 +724,32 @@ def generate_video(
    global gen_in_progress
    gen_in_progress = True
    temp_filename = None
    if len(prompt) == 0:
        return
    prompts = prompt.replace("\r", "").split("\n")
    if use_image2video:
        if image_to_continue is not None:
            pass
        # a Gallery passes a list of (image, caption) tuples, a single Image component just the image
        if isinstance(image_to_continue, list):
            image_to_continue = [tup[0] for tup in image_to_continue]
        else:
            image_to_continue = [image_to_continue]
        # pair prompts with starting images by repeating the shorter list
        if len(prompts) >= len(image_to_continue):
            if len(prompts) % len(image_to_continue) != 0:
                raise gr.Error("If there are more text prompts than input images, the number of text prompts should be divisible by the number of images")
            rep = len(prompts) // len(image_to_continue)
            new_image_to_continue = []
            for i, _ in enumerate(prompts):
                new_image_to_continue.append(image_to_continue[i // rep])
            image_to_continue = new_image_to_continue
        else:
            if len(image_to_continue) % len(prompts) != 0:
                raise gr.Error("If there are more input images than text prompts, the number of images should be divisible by the number of text prompts")
            rep = len(image_to_continue) // len(prompts)
            new_prompts = []
            for i, _ in enumerate(image_to_continue):
                new_prompts.append(prompts[i // rep])
            prompts = new_prompts
    elif video_to_continue is not None and len(video_to_continue) > 0:
        input_image_or_video_path = video_to_continue
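For reference, the prompt/image pairing above boils down to repeating the shorter of the two lists, with the longer list's length required to be a multiple of the shorter one's. A minimal standalone sketch of that rule (the function name and the use of ValueError instead of gr.Error are illustrative):
```python
def pair_prompts_and_images(prompts: list, images: list) -> list:
    """Repeat the shorter list so every generated video gets a (prompt, image) pair."""
    if len(prompts) >= len(images):
        if len(prompts) % len(images) != 0:
            raise ValueError("the number of prompts must be divisible by the number of images")
        rep = len(prompts) // len(images)
        images = [images[i // rep] for i in range(len(prompts))]
    else:
        if len(images) % len(prompts) != 0:
            raise ValueError("the number of images must be divisible by the number of prompts")
        rep = len(images) // len(prompts)
        prompts = [prompts[i // rep] for i in range(len(images))]
    return list(zip(prompts, images))

# 4 prompts and 2 images -> [(p1, img1), (p2, img1), (p3, img2), (p4, img2)]
print(pair_prompts_and_images(["p1", "p2", "p3", "p4"], ["img1", "img2"]))
```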
@ -791,7 +819,6 @@ def generate_video(
    from einops import rearrange
    save_path = os.path.join(os.getcwd(), "gradio_outputs")
    os.makedirs(save_path, exist_ok=True)
    prompts = prompt.replace("\r", "").split("\n")
    video_no = 0
    total_video = repeat_generation * len(prompts)
    abort = False
@ -826,7 +853,7 @@ def generate_video(
            if use_image2video:
                samples = wan_model.generate(
                    prompt,
                    image_to_continue,
                    image_to_continue[video_no-1],
                    frame_num=(video_length // 4)* 4 + 1,
                    max_area=MAX_AREA_CONFIGS[resolution],
                    shift=flow_shift,
@ -1018,7 +1045,7 @@ def create_demo():
        header = gr.Markdown(generate_header(transformer_filename_i2v if use_image2video else transformer_filename_t2v, compile, attention_mode) )
        with gr.Accordion("Video Engine Configuration - click here to change it", open = False):
        with gr.Accordion("Video Engine Configuration - click here to change it", open = False, visible= not args.lock_config):
            gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.")
            with gr.Column():
@ -1100,7 +1127,7 @@ def create_demo():
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
profile_choice = gr.Dropdown(
@ -1131,8 +1158,13 @@ def create_demo():
            with gr.Row():
                with gr.Column():
                    video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
                    if args.multiple_images:
                        image_to_continue = gr.Gallery(
                            label="Images as a starting point for new videos", type="pil", #file_types= "image",
                            columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive=True, visible=use_image2video)
                    else:
                        image_to_continue = gr.Image(label="Image as a starting point for a new video", visible=use_image2video)

                    if use_image2video:
                        prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
@ -1317,6 +1349,8 @@ if __name__ == "__main__":
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
if server_port == 0:
server_port = int(os.getenv("SERVER_PORT", "7860"))
@ -1334,6 +1368,6 @@ if __name__ == "__main__":
url = "http://" + server_name
webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
demo.launch(server_name=server_name, server_port=server_port)
demo.launch(server_name=server_name, server_port=server_port, share=args.share)

View File

@ -16,4 +16,4 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.2.1
mmgp==3.2.2

View File

@ -541,6 +541,6 @@ class CLIPModel:
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
        # forward
        with torch.cuda.amp.autocast(dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        with torch.amp.autocast(dtype=self.dtype, device_type="cuda"):
            out = self.model.visual(videos.to(torch.bfloat16), use_31_block=True)
        return out
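The hunk above replaces the deprecated torch.cuda.amp.autocast context manager with torch.amp.autocast and an explicit device_type, and casts the CLIP input to bfloat16. A minimal sketch of the same pattern, assuming a CUDA device with bfloat16 support:
```python
import torch

linear = torch.nn.Linear(8, 8).to("cuda")
x = torch.randn(4, 8, device="cuda")

# torch.cuda.amp.autocast(...) is deprecated; torch.amp.autocast takes the device explicitly
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = linear(x.to(torch.bfloat16))
print(y.dtype)  # torch.bfloat16
```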