Added Vace ControlNet support

This commit is contained in:
DeepBeepMeep 2025-04-05 02:06:17 +02:00
parent 8f1e6149a8
commit cf495a4aaf
7 changed files with 929 additions and 181 deletions

View File

@ -14,7 +14,7 @@ import gradio as gr
import random import random
import json import json
import wan import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import cache_video from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes, get_supported_attention_modes from wan.modules.attention import get_attention_modes, get_supported_attention_modes
import torch import torch
@ -55,6 +55,11 @@ def format_time(seconds):
def pil_to_base64_uri(pil_image, format="png", quality=75): def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None: if pil_image is None:
return None return None
if isinstance(pil_image, str):
from wan.utils.utils import get_video_frame
pil_image = get_video_frame(pil_image, 0)
buffer = io.BytesIO() buffer = io.BytesIO()
try: try:
img_to_save = pil_image img_to_save = pil_image
@ -93,10 +98,11 @@ def process_prompt_and_add_tasks(
loras_choices, loras_choices,
loras_mult_choices, loras_mult_choices,
image_prompt_type, image_prompt_type,
image_to_continue, image_source1,
image_to_end, image_source2,
video_to_continue, image_source3,
max_frames, max_frames,
remove_background_image_ref,
temporal_upsampling, temporal_upsampling,
spatial_upsampling, spatial_upsampling,
RIFLEx_setting, RIFLEx_setting,
@ -127,9 +133,9 @@ def process_prompt_and_add_tasks(
return return
file_model_needed = model_needed(image2video) file_model_needed = model_needed(image2video)
width, height = resolution.split("x")
width, height = int(width), int(height)
if image2video: if image2video:
width, height = resolution.split("x")
width, height = int(width), int(height)
if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480: if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@ -143,74 +149,94 @@ def process_prompt_and_add_tasks(
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return return
if image2video: if not image2video:
if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0: if "Vace" in file_model_needed and "1.3B" in file_model_needed :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = " and ".join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"The selected resolution is not supported by the Vace model. Only the {res} resolutions are allowed.")
return
if not "I" in image_prompt_type:
image_source1 = None
if not "V" in image_prompt_type:
image_source2 = None
if not "M" in image_prompt_type:
image_source3 = None
if isinstance(image_source1, list):
image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
from wan.utils.utils import resize_and_remove_background
image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1)
image_source1 = [ image_source1 ] * len(prompts)
image_source2 = [ image_source2 ] * len(prompts)
image_source3 = [ image_source3 ] * len(prompts)
else:
if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0:
return return
if image_prompt_type == 0: if image_prompt_type == 0:
image_to_end = None image_source2 = None
if isinstance(image_to_continue, list): if isinstance(image_source1, list):
image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ] image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
else: else:
image_to_continue = [convert_image(image_to_continue)] image_source1 = [convert_image(image_source1)]
if image_to_end != None: if image_source2 != None:
if isinstance(image_to_end , list): if isinstance(image_source2 , list):
image_to_end = [ convert_image(tup[0]) for tup in image_to_end ] image_source2 = [ convert_image(tup[0]) for tup in image_source2 ]
else: else:
image_to_end = [convert_image(image_to_end) ] image_source2 = [convert_image(image_source2) ]
if len(image_to_continue) != len(image_to_end): if len(image_source1) != len(image_source2):
gr.Info("The number of start and end images should be the same ") gr.Info("The number of start and end images should be the same ")
return return
if multi_images_gen_type == 0: if multi_images_gen_type == 0:
new_prompts = [] new_prompts = []
new_image_to_continue = [] new_image_source1 = []
new_image_to_end = [] new_image_source2 = []
for i in range(len(prompts) * len(image_to_continue) ): for i in range(len(prompts) * len(image_source1) ):
new_prompts.append( prompts[ i % len(prompts)] ) new_prompts.append( prompts[ i % len(prompts)] )
new_image_to_continue.append(image_to_continue[i // len(prompts)] ) new_image_source1.append(image_source1[i // len(prompts)] )
if image_to_end != None: if image_source2 != None:
new_image_to_end.append(image_to_end[i // len(prompts)] ) new_image_source2.append(image_source2[i // len(prompts)] )
prompts = new_prompts prompts = new_prompts
image_to_continue = new_image_to_continue image_source1 = new_image_source1
if image_to_end != None: if image_source2 != None:
image_to_end = new_image_to_end image_source2 = new_image_source2
else: else:
if len(prompts) >= len(image_to_continue): if len(prompts) >= len(image_source1):
if len(prompts) % len(image_to_continue) !=0: if len(prompts) % len(image_source1) !=0:
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
rep = len(prompts) // len(image_to_continue) rep = len(prompts) // len(image_source1)
new_image_to_continue = [] new_image_source1 = []
new_image_to_end = [] new_image_source2 = []
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
new_image_to_continue.append(image_to_continue[i//rep] ) new_image_source1.append(image_source1[i//rep] )
if image_to_end != None: if image_source2 != None:
new_image_to_end.append(image_to_end[i//rep] ) new_image_source2.append(image_source2[i//rep] )
image_to_continue = new_image_to_continue image_source1 = new_image_source1
if image_to_end != None: if image_source2 != None:
image_to_end = new_image_to_end image_source2 = new_image_source2
else: else:
if len(image_to_continue) % len(prompts) !=0: if len(image_source1) % len(prompts) !=0:
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
rep = len(image_to_continue) // len(prompts) rep = len(image_source1) // len(prompts)
new_prompts = [] new_prompts = []
for i, _ in enumerate(image_to_continue): for i, _ in enumerate(image_source1):
new_prompts.append( prompts[ i//rep] ) new_prompts.append( prompts[ i//rep] )
prompts = new_prompts prompts = new_prompts
# elif video_to_continue != None and len(video_to_continue) >0 :
# input_image_or_video_path = video_to_continue if image_source1 == None:
# # pipeline.num_input_frames = max_frames image_source1 = [None] * len(prompts)
# # pipeline.max_frames = max_frames if image_source2 == None:
# else: image_source2 = [None] * len(prompts)
# return if image_source3 == None:
# else: image_source3 = [None] * len(prompts)
# input_image_or_video_path = None
if image_to_continue == None:
image_to_continue = [None] * len(prompts)
if image_to_end == None:
image_to_end = [None] * len(prompts)
for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) : for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) :
kwargs = { kwargs = {
"prompt" : single_prompt, "prompt" : single_prompt,
"negative_prompt" : negative_prompt, "negative_prompt" : negative_prompt,
@ -228,10 +254,11 @@ def process_prompt_and_add_tasks(
"loras_choices" : loras_choices, "loras_choices" : loras_choices,
"loras_mult_choices" : loras_mult_choices, "loras_mult_choices" : loras_mult_choices,
"image_prompt_type" : image_prompt_type, "image_prompt_type" : image_prompt_type,
"image_to_continue": image_start, "image_source1": image_source1,
"image_to_end" : image_end, "image_source2" : image_source2,
"video_to_continue" : video_to_continue , "image_source3" : image_source3 ,
"max_frames" : max_frames, "max_frames" : max_frames,
"remove_background_image_ref" : remove_background_image_ref,
"temporal_upsampling" : temporal_upsampling, "temporal_upsampling" : temporal_upsampling,
"spatial_upsampling" : spatial_upsampling, "spatial_upsampling" : spatial_upsampling,
"RIFLEx_setting" : RIFLEx_setting, "RIFLEx_setting" : RIFLEx_setting,
@ -262,8 +289,9 @@ def add_video_task(**kwargs):
queue = gen["queue"] queue = gen["queue"]
task_id += 1 task_id += 1
current_task_id = task_id current_task_id = task_id
start_image_data = kwargs["image_to_continue"] start_image_data = kwargs["image_source1"]
end_image_data = kwargs["image_to_end"] start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data
end_image_data = kwargs["image_source2"]
queue.append({ queue.append({
"id": current_task_id, "id": current_task_id,
@ -275,7 +303,7 @@ def add_video_task(**kwargs):
"prompt": kwargs["prompt"], "prompt": kwargs["prompt"],
"start_image_data": start_image_data, "start_image_data": start_image_data,
"end_image_data": end_image_data, "end_image_data": end_image_data,
"start_image_data_base64": pil_to_base64_uri(start_image_data, format="jpeg", quality=70), "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data],
"end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70) "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
}) })
return update_queue_data(queue) return update_queue_data(queue)
@ -342,6 +370,7 @@ def get_queue_table(queue):
full_prompt = item['prompt'].replace('"', '&quot;') full_prompt = item['prompt'].replace('"', '&quot;')
prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>' prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
start_img_uri =item.get('start_image_data_base64') start_img_uri =item.get('start_image_data_base64')
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
end_img_uri = item.get('end_image_data_base64') end_img_uri = item.get('end_image_data_base64')
thumbnail_size = "50px" thumbnail_size = "50px"
num_steps = item.get('steps') num_steps = item.get('steps')
@ -694,6 +723,9 @@ attention_modes_installed = get_attention_modes()
attention_modes_supported = get_supported_attention_modes() attention_modes_supported = get_supported_attention_modes()
args = _parse_args() args = _parse_args()
args.flow_reverse = True args.flow_reverse = True
processing_device = args.gpu
if len(processing_device) == 0:
processing_device ="cuda"
# torch.backends.cuda.matmul.allow_fp16_accumulation = True # torch.backends.cuda.matmul.allow_fp16_accumulation = True
lock_ui_attention = False lock_ui_attention = False
lock_ui_transformer = False lock_ui_transformer = False
@ -706,7 +738,7 @@ quantizeTransformer = args.quantize_transformer
check_loras = args.check_loras ==1 check_loras = args.check_loras ==1
advanced = args.advanced advanced = args.advanced
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"] transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
@ -750,7 +782,7 @@ def get_default_settings(filename, i2v):
"prompts": get_default_prompt(i2v), "prompts": get_default_prompt(i2v),
"resolution": "832x480", "resolution": "832x480",
"video_length": 81, "video_length": 81,
"image_prompt_type" : 0, "image_prompt_type" : 0 if i2v else "",
"num_inference_steps": 30, "num_inference_steps": 30,
"seed": -1, "seed": -1,
"repeat_generation": 1, "repeat_generation": 1,
@ -1149,6 +1181,9 @@ def get_model_name(model_filename):
if "Fun" in model_filename: if "Fun" in model_filename:
model_name = "Fun InP image2video" model_name = "Fun InP image2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B" model_name += " 14B" if "14B" in model_filename else " 1.3B"
elif "Vace" in model_filename:
model_name = "Vace ControlNet text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
elif "image" in model_filename: elif "image" in model_filename:
model_name = "Wan2.1 image2video" model_name = "Wan2.1 image2video"
model_name += " 720p" if "720p" in model_filename else " 480p" model_name += " 720p" if "720p" in model_filename else " 480p"
@ -1353,22 +1388,22 @@ def refresh_gallery(state, msg):
end_img_md = "" end_img_md = ""
prompt = task["prompt"] prompt = task["prompt"]
if task.get('image2video'): start_img_uri = task.get('start_image_data_base64')
start_img_uri = task.get('start_image_data_base64') start_img_uri = start_img_uri[0] if start_img_uri !=None else None
end_img_uri = task.get('end_image_data_base64') end_img_uri = task.get('end_image_data_base64')
thumbnail_size = "100px" thumbnail_size = "100px"
if start_img_uri: if start_img_uri:
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />' start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
if end_img_uri: if end_img_uri:
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />' end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
label = f"Prompt of Video being Generated" label = f"Prompt of Video being Generated"
html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>" html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
if start_img_md != "": if start_img_md != "":
html += "<TD>" + start_img_md + "</TD>" html += "<TD>" + start_img_md + "</TD>"
if end_img_md != "": if end_img_md != "":
html += "<TD>" + end_img_md + "</TD>" html += "<TD>" + end_img_md + "</TD>"
html += "</TR></TABLE>" html += "</TR></TABLE>"
html_output = gr.HTML(html, visible= True) html_output = gr.HTML(html, visible= True)
@ -1419,24 +1454,26 @@ def expand_slist(slist, num_inference_steps ):
new_slist.append(slist[ int(pos)]) new_slist.append(slist[ int(pos)])
pos += inc pos += inc
return new_slist return new_slist
def convert_image(image): def convert_image(image):

    from PIL import ExifTags, ImageOps
    from typing import cast

    return cast(Image, ImageOps.exif_transpose(image))
    # The previous body (convert the image to 'RGB', look up the EXIF Orientation
    # tag, rotate by 180/270/90 degrees accordingly and return the image) is left
    # commented out in the new version of the file.
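For context, a minimal sketch of the Pillow call the rewritten convert_image relies on: ImageOps.exif_transpose applies the EXIF Orientation tag and returns an upright copy of the image (the file name below is purely illustrative).

    from PIL import Image, ImageOps

    img = Image.open("photo.jpg")             # any image that may carry an EXIF Orientation tag
    upright = ImageOps.exif_transpose(img)    # rotated/flipped according to EXIF, content unchanged otherwise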
def generate_video( def generate_video(
task_id, task_id,
@ -1457,10 +1494,11 @@ def generate_video(
loras_choices, loras_choices,
loras_mult_choices, loras_mult_choices,
image_prompt_type, image_prompt_type,
image_to_continue, image_source1,
image_to_end, image_source2,
video_to_continue, image_source3,
max_frames, max_frames,
remove_background_image_ref,
temporal_upsampling, temporal_upsampling,
spatial_upsampling, spatial_upsampling,
RIFLEx_setting, RIFLEx_setting,
@ -1507,7 +1545,6 @@ def generate_video(
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.") gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
return return
if not image2video: if not image2video:
width, height = resolution.split("x") width, height = resolution.split("x")
@ -1586,7 +1623,7 @@ def generate_video(
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1 enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling # VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
joint_pass = boost ==1 #and profile != 1 and profile != 3 joint_pass = boost ==1 #and profile != 1 and profile != 3
# TeaCache # TeaCache
@ -1615,6 +1652,17 @@ def generate_video(
else: else:
raise gr.Error("Teacache not supported for this model") raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
resolution_reformated = str(height) + "*" + str(width)
src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2],
[image_source3],
[image_source1],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
trim_video=max_frames)
else:
src_video, src_mask, src_ref_images = None, None, None
import random import random
if seed == None or seed <0: if seed == None or seed <0:
seed = random.randint(0, 999999999) seed = random.randint(0, 999999999)
@ -1673,8 +1721,8 @@ def generate_video(
if image2video: if image2video:
samples = wan_model.generate( samples = wan_model.generate(
prompt, prompt,
image_to_continue, image_source1,
image_to_end if image_to_end != None else None, image_source2 if image_source2 != None else None,
frame_num=(video_length // 4)* 4 + 1, frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution], max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift, shift=flow_shift,
@ -1697,6 +1745,9 @@ def generate_video(
else: else:
samples = wan_model.generate( samples = wan_model.generate(
prompt, prompt,
input_frames = src_video,
input_ref_images= src_ref_images,
input_masks = src_mask,
frame_num=(video_length // 4)* 4 + 1, frame_num=(video_length // 4)* 4 + 1,
size=(width, height), size=(width, height),
shift=flow_shift, shift=flow_shift,
@ -1745,7 +1796,7 @@ def generate_video(
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames." new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
else: else:
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'") new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
tb = traceback.format_exc().split('\n')[:-2] tb = traceback.format_exc().split('\n')[:-1]
print('\n'.join(tb)) print('\n'.join(tb))
raise gr.Error(new_error, print_exception= False) raise gr.Error(new_error, print_exception= False)
@ -1799,7 +1850,7 @@ def generate_video(
if exp > 0: if exp > 0:
from rife.inference import temporal_interpolation from rife.inference import temporal_interpolation
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda") sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
fps = fps * 2**exp fps = fps * 2**exp
if len(spatial_upsampling) > 0: if len(spatial_upsampling) > 0:
@ -1831,8 +1882,7 @@ def generate_video(
normalize=True, normalize=True,
value_range=(-1, 1)) value_range=(-1, 1))
configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step) loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
metadata_choice = server_config.get("metadata_choice","metadata") metadata_choice = server_config.get("metadata_choice","metadata")
@ -2294,7 +2344,7 @@ def switch_advanced(state, new_advanced, lset_name):
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name) return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
loras = state["loras"] loras = state["loras"]
@ -2330,18 +2380,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video" ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
ui_settings["image_prompt_type"] = image_prompt_type ui_settings["image_prompt_type"] = image_prompt_type
else: else:
if "Vace" in transformer_filename_t2v or not image_metadata:
ui_settings["image_prompt_type"] = image_prompt_type
ui_settings["max_frames"] = max_frames
ui_settings["remove_background_image_ref"] = remove_background_image_ref
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video" ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
return ui_settings return ui_settings
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
if state.get("validate_success",0) != 1: if state.get("validate_success",0) != 1:
return return
image2video = state["image2video"] image2video = state["image2video"]
ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step) loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
defaults_filename = get_settings_file_name(image2video) defaults_filename = get_settings_file_name(image2video)
@ -2379,6 +2433,25 @@ def download_loras():
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}") writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
return return
def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio):
if args.multiple_images:
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
else:
return gr.Image(visible = (image_prompt_type_radio == 1) )
def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio):
vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"]
return gr.Column(visible= vace_model), gr.Radio(value= image_prompt_type_radio), gr.Gallery(visible = "I" in image_prompt_type_radio), gr.Video(visible= "V" in image_prompt_type_radio),gr.Video(visible= "M" in image_prompt_type_radio ), gr.Text(visible= "V" in image_prompt_type_radio) , gr.Checkbox(visible= "I" in image_prompt_type_radio)
def check_refresh_input_type(state):
if not state["image2video"]:
model_file_name = state["image_input_type_model"]
model_file_needed= model_needed(False)
if model_file_name != model_file_needed:
state["image_input_type_model"] = model_file_needed
return gr.Text(value= str(time.time()))
return gr.Text()
def generate_video_tab(image2video=False): def generate_video_tab(image2video=False):
filename = transformer_filename_i2v if image2video else transformer_filename_t2v filename = transformer_filename_i2v if image2video else transformer_filename_t2v
ui_defaults= get_default_settings(filename, image2video) ui_defaults= get_default_settings(filename, image2video)
@ -2387,6 +2460,7 @@ def generate_video_tab(image2video=False):
state_dict["advanced"] = advanced state_dict["advanced"] = advanced
state_dict["loras_model"] = filename state_dict["loras_model"] = filename
state_dict["image_input_type_model"] = filename
state_dict["image2video"] = image2video state_dict["image2video"] = image2video
gen = dict() gen = dict()
gen["queue"] = [] gen["queue"] = []
@ -2461,31 +2535,51 @@ def generate_video_tab(image2video=False):
save_lset_btn = gr.Button("Save", size="sm", min_width= 1) save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
image_prompt_type= ui_defaults.get("image_prompt_type",0)
image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
if args.multiple_images: state = gr.State(state_dict)
image_to_continue = gr.Gallery( vace_model = "Vace" in filename and not image2video
label="Images as starting points for new videos", type ="pil", #file_types= "image", trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video) with gr.Column(visible= image2video or vace_model) as image_prompt_column:
else: if image2video:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video) image_source3 = gr.Video(label= "Placeholder", visible= image2video and False)
if args.multiple_images: image_prompt_type= ui_defaults.get("image_prompt_type",0)
image_to_end = gr.Gallery( image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3)
label="Images as ending points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
else:
image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
def switch_image_prompt_type_radio(image_prompt_type_radio): if args.multiple_images:
if args.multiple_images: image_source1 = gr.Gallery(
return gr.Gallery(visible = (image_prompt_type_radio == 1) ) label="Images as starting points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True)
else:
image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil")
if args.multiple_images:
image_source2 = gr.Gallery(
label="Images as ending points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
else:
image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2])
max_frames = gr.Slider(1, 100,step=1, visible = False)
remove_background_image_ref = gr.Text(visible = False)
else: else:
return gr.Image(visible = (image_prompt_type_radio == 1) ) image_prompt_type= ui_defaults.get("image_prompt_type","I")
image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model)
image_source1 = gr.Gallery(
label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type )
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end]) image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type )
with gr.Row():
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 )
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in image_prompt_type, scale =1 )
image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type )
gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref])
advanced_prompt = advanced advanced_prompt = advanced
@ -2518,7 +2612,6 @@ def generate_video_tab(image2video=False):
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3) wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False) wizard_variables_var = gr.Text(wizard_variables, visible = False)
state = gr.State(state_dict)
with gr.Row(): with gr.Row():
if image2video: if image2video:
resolution = gr.Dropdown( resolution = gr.Dropdown(
@ -2555,8 +2648,6 @@ def generate_video_tab(image2video=False):
video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)") video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
with gr.Column(): with gr.Column():
num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps") num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
with gr.Row():
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=image2video and False) #########
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced) show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
with gr.Row(visible=advanced) as advanced_row: with gr.Row(visible=advanced) as advanced_row:
with gr.Column(): with gr.Column():
@ -2605,7 +2696,7 @@ def generate_video_tab(image2video=False):
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation") tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
with gr.Row(): with gr.Row():
gr.Markdown("<B>Upsampling</B>") gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
with gr.Row(): with gr.Row():
temporal_upsampling_choice = gr.Dropdown( temporal_upsampling_choice = gr.Dropdown(
choices=[ choices=[
@ -2687,9 +2778,10 @@ def generate_video_tab(image2video=False):
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column(): with gr.Column():
gen_status = gr.Text(label="Status", interactive= False) gen_status = gr.Text(interactive= False)
full_sync = gr.Text(label="Status", interactive= False, visible= False) full_sync = gr.Text(interactive= False, visible= False)
light_sync = gr.Text(label="Status", interactive= False, visible= False) light_sync = gr.Text(interactive= False, visible= False)
gen_progress_html = gr.HTML( gen_progress_html = gr.HTML(
label="Status", label="Status",
value="Idle", value="Idle",
@ -2709,8 +2801,8 @@ def generate_video_tab(image2video=False):
abort_btn = gr.Button("Abort") abort_btn = gr.Button("Abort")
queue_df = gr.DataFrame( queue_df = gr.DataFrame(
headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""], headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
interactive=False, interactive=False,
col_count=(9, "fixed"), col_count=(9, "fixed"),
@ -2792,7 +2884,7 @@ def generate_video_tab(image2video=False):
show_progress="hidden" show_progress="hidden"
) )
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = []) slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@ -2808,21 +2900,30 @@ def generate_video_tab(image2video=False):
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
output.select(select_video, state, None ) output.select(select_video, state, None )
gen_status.change(refresh_gallery, gen_status.change(refresh_gallery,
inputs = [state, gen_status], inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]) outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
full_sync.change(refresh_gallery, full_sync.change(fn= check_refresh_input_type,
inputs= [state],
outputs= [trigger_refresh_input_type]
).then(fn=refresh_gallery,
inputs = [state, gen_status], inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
).then( fn=wait_tasks_done, ).then(fn=wait_tasks_done,
inputs= [state], inputs= [state],
outputs =[gen_status], outputs =[gen_status],
).then(finalize_generation, ).then(finalize_generation,
inputs= [state], inputs= [state],
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
) )
light_sync.change(refresh_gallery,
light_sync.change(fn= check_refresh_input_type,
inputs= [state],
outputs= [trigger_refresh_input_type]
).then(fn=refresh_gallery,
inputs = [state, gen_status], inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
) )
@ -2848,10 +2949,11 @@ def generate_video_tab(image2video=False):
loras_choices, loras_choices,
loras_mult_choices, loras_mult_choices,
image_prompt_type_radio, image_prompt_type_radio,
image_to_continue, image_source1,
image_to_end, image_source2,
video_to_continue, image_source3,
max_frames, max_frames,
remove_background_image_ref,
temporal_upsampling_choice, temporal_upsampling_choice,
spatial_upsampling_choice, spatial_upsampling_choice,
RIFLEx_setting, RIFLEx_setting,
@ -2902,7 +3004,7 @@ def generate_video_tab(image2video=False):
) )
return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state): def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state):
with gr.Row(): with gr.Row():
with gr.Row(scale =2): with gr.Row(scale =2):
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).") gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
@ -2928,6 +3030,7 @@ def generate_configuration_tab():
("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0), ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1), ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2), ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
], ],
value= index, value= index,
label="Transformer model for Text to Video", label="Transformer model for Text to Video",
@ -3108,16 +3211,17 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
t2v_light_sync = gr.Text() t2v_light_sync = gr.Text()
i2v_full_sync = gr.Text() i2v_full_sync = gr.Text()
t2v_full_sync = gr.Text() t2v_full_sync = gr.Text()
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
if last_tab_was_image2video == None or last_tab_was_image2video:
gen = i2v_state["gen"]
t2v_state["gen"] = gen
else:
gen = t2v_state["gen"]
i2v_state["gen"] = gen
if new_t2v or new_i2v: if new_t2v or new_i2v:
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
if last_tab_was_image2video == None or last_tab_was_image2video:
gen = i2v_state["gen"]
t2v_state["gen"] = gen
else:
gen = t2v_state["gen"]
i2v_state["gen"] = gen
if last_tab_was_image2video != None and new_t2v != new_i2v: if last_tab_was_image2video != None and new_t2v != new_i2v:
gen_location = gen.get("location", None) gen_location = gen.get("location", None)
if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) : if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
@ -3131,7 +3235,6 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
else: else:
t2v_light_sync = gr.Text(str(time.time())) t2v_light_sync = gr.Text(str(time.time()))
global_state["last_tab_was_image2video"] = new_i2v global_state["last_tab_was_image2video"] = new_i2v
if(server_config.get("reload_model",2) == 1): if(server_config.get("reload_model",2) == 1):
@ -3433,7 +3536,7 @@ def create_demo():
} }
""" """
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo: with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>") gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>") gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@ -3454,7 +3557,7 @@ def create_demo():
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True) i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
if not args.lock_config: if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab: with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
with gr.Tab("Configuration"): with gr.Tab("Configuration"):
generate_configuration_tab() generate_configuration_tab()
with gr.Tab("About"): with gr.Tab("About"):

View File

@ -11,11 +11,15 @@ easydict
ftfy ftfy
dashscope dashscope
imageio-ffmpeg imageio-ffmpeg
# flash_attn # flash_attn
gradio>=5.0.0 gradio>=5.0.0
numpy>=1.23.5,<2 numpy>=1.23.5,<2
einops einops
moviepy==1.0.3 moviepy==1.0.3
mmgp==3.3.4 mmgp==3.3.4
peft==0.14.0 peft==0.14.0
mutagen mutagen
decord
onnxruntime-gpu
rembg[gpu]==2.0.65
# rembg==2.0.65

View File

@ -40,3 +40,17 @@ SUPPORTED_SIZES = {
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()), 't2i-14B': tuple(SIZE_CONFIGS.keys()),
} }
VACE_SIZE_CONFIGS = {
'480*832': (480, 832),
'832*480': (832, 480),
}
VACE_MAX_AREA_CONFIGS = {
'480*832': 480 * 832,
'832*480': 832 * 480,
}
VACE_SUPPORTED_SIZES = {
'vace-1.3B': ('480*832', '832*480'),
}
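The UI stores resolutions as "widthxheight" strings (e.g. "832x480"), so the validation added in this commit rebuilds a "height*width" key before looking it up in these tables. A simplified sketch of that lookup, reusing the dict above; is_supported_vace_resolution is a hypothetical helper name, not part of the commit:

    VACE_SIZE_CONFIGS = {
        '480*832': (480, 832),
        '832*480': (832, 480),
    }

    def is_supported_vace_resolution(resolution: str) -> bool:
        # UI resolution "832x480" -> config key "480*832"
        width, height = (int(v) for v in resolution.split("x"))
        return f"{height}*{width}" in VACE_SIZE_CONFIGS

    # is_supported_vace_resolution("832x480")  -> True
    # is_supported_vace_resolution("1280x720") -> False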

View File

@ -377,6 +377,7 @@ class WanI2VCrossAttention(WanSelfAttention):
return x return x
WAN_CROSSATTENTION_CLASSES = { WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention, 't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention, 'i2v_cross_attn': WanI2VCrossAttention,
@ -393,7 +394,9 @@ class WanAttentionBlock(nn.Module):
window_size=(-1, -1), window_size=(-1, -1),
qk_norm=True, qk_norm=True,
cross_attn_norm=False, cross_attn_norm=False,
eps=1e-6): eps=1e-6,
block_id=None
):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.ffn_dim = ffn_dim self.ffn_dim = ffn_dim
@ -422,6 +425,7 @@ class WanAttentionBlock(nn.Module):
# modulation # modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.block_id = block_id
def forward( def forward(
self, self,
@ -432,6 +436,8 @@ class WanAttentionBlock(nn.Module):
freqs, freqs,
context, context,
context_lens, context_lens,
hints= None,
context_scale=1.0,
): ):
r""" r"""
Args: Args:
@ -480,10 +486,49 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[5]) x.addcmul_(y, e[5])
if self.block_id is not None and hints != None:
if context_scale == 1:
x.add_(hints[self.block_id])
else:
x.add_(hints[self.block_id], alpha =context_scale)
return x
return x class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
# behold dbm magic !
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = c
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
return all_c
class Head(nn.Module): class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6): def __init__(self, dim, out_dim, patch_size, eps=1e-6):
@ -544,6 +589,8 @@ class WanModel(ModelMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__(self, def __init__(self,
vace_layers=None,
vace_in_dim=None,
model_type='t2v', model_type='t2v',
patch_size=(1, 2, 2), patch_size=(1, 2, 2),
text_len=512, text_len=512,
@ -628,12 +675,13 @@ class WanModel(ModelMixin, ConfigMixin):
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks # blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' if vace_layers == None:
self.blocks = nn.ModuleList([ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, self.blocks = nn.ModuleList([
window_size, qk_norm, cross_attn_norm, eps) WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
for _ in range(num_layers) window_size, qk_norm, cross_attn_norm, eps)
]) for _ in range(num_layers)
])
# head # head
self.head = Head(dim, out_dim, patch_size, eps) self.head = Head(dim, out_dim, patch_size, eps)
@ -646,6 +694,33 @@ class WanModel(ModelMixin, ConfigMixin):
# initialize weights # initialize weights
self.init_weights() self.init_weights()
if vace_layers != None:
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# blocks
self.blocks = nn.ModuleList([
WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps,
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps, block_id=i)
for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients) rescale_func = np.poly1d(self.coefficients)
@ -688,6 +763,36 @@ class WanModel(ModelMixin, ConfigMixin):
self.rel_l1_thresh = best_threshold self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold return best_threshold
def forward_vace(
self,
x,
vace_context,
seq_len,
context,
e,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
if (len(c) == 1 and seq_len == c[0].size(1)):
c = c[0]
else:
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for block in self.vace_blocks:
c = block(c, context= context, e= e, **new_kwargs)
hints = c[:-1]
return hints
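To make the data flow easier to follow, here is a rough, self-contained toy of the bookkeeping: each vace block grows a list [skip_0, ..., skip_k, c] over the control stream, forward_vace keeps everything but the last entry as hints, and the main blocks whose index appears in vace_layers add their hint back into the latent stream (see the forward method below). The attention computations are replaced by pass-throughs and all shapes are toy values; only the list handling and the every-other-layer mapping mirror the commit.

    import torch
    from torch import nn

    dim, seq_len, num_layers = 8, 4, 6
    vace_layers = list(range(0, num_layers, 2))               # default mapping: every other layer
    vace_mapping = {i: n for n, i in enumerate(vace_layers)}  # transformer layer -> vace block index

    after_proj = nn.ModuleList([nn.Linear(dim, dim) for _ in vace_layers])

    x = torch.randn(1, seq_len, dim)   # main latent stream
    c = torch.randn(1, seq_len, dim)   # VACE control stream (already patch-embedded)

    # forward_vace / VaceWanAttentionBlock: each vace block appends [c_skip, c]
    all_c = []
    for n, proj in enumerate(after_proj):
        if n == 0:
            c = c + x                  # block 0 mixes the latent stream into the control stream
        else:
            c = all_c.pop(-1)          # continue from the running control stream
        # (the real code runs a full attention block on c here)
        all_c += [proj(c), c]
    hints = all_c[:-1]                 # one skip tensor per vace block

    # main transformer stack: blocks listed in vace_layers add their hint to x
    context_scale = 1.0
    for i in range(num_layers):
        # (the real code runs WanAttentionBlock on x here)
        if i in vace_mapping:
            x = x + context_scale * hints[vace_mapping[i]]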
def forward( def forward(
self, self,
@ -695,6 +800,8 @@ class WanModel(ModelMixin, ConfigMixin):
t, t,
context, context,
seq_len, seq_len,
vace_context = None,
vace_context_scale=1.0,
clip_fea=None, clip_fea=None,
y=None, y=None,
freqs = None, freqs = None,
@ -829,13 +936,23 @@ class WanModel(ModelMixin, ConfigMixin):
self.previous_residual_cond = None self.previous_residual_cond = None
ori_hidden_states = x_list[0].clone() ori_hidden_states = x_list[0].clone()
# arguments # arguments
kwargs = dict( kwargs = dict(
# e=e0,
seq_lens=seq_lens, seq_lens=seq_lens,
grid_sizes=grid_sizes, grid_sizes=grid_sizes,
freqs=freqs, freqs=freqs,
# context=context,
context_lens=context_lens) context_lens=context_lens)
if vace_context == None:
hints_list = [None ] *len(x_list)
else:
hints_list = []
for x, context in zip(x_list, context_list) :
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
del x, context
kwargs['context_scale'] = vace_context_scale
        for block_idx, block in enumerate(self.blocks):
            offload.shared_state["layer"] = block_idx
            if callback != None:
@ -852,9 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
                x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
            else:
                for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
                    x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
                    del x
                del context, hints
        if self.enable_teacache:
            if joint_pass:

View File

@ -13,7 +13,9 @@ import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -22,6 +24,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
def optimized_scale(positive_flat, negative_flat):
@ -105,8 +108,6 @@ class WanT2V:
        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
        self.model.eval().requires_grad_(False)
        if use_usp:
@ -132,8 +133,148 @@ class WanT2V:
        self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
else:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
else:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
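A toy shape check of the latent layout produced by vace_encode_frames, with assumed dimensions (16-channel VAE latents, 21 latent frames, one reference image): inactive and reactive latents are stacked on the channel axis, and reference-image latents, zero-padded to the same channel count, are prepended on the time axis.

# Hypothetical dimensions -- the point is the channel/time layout, not the values.
import torch

z_dim, t_lat, h_lat, w_lat = 16, 21, 60, 104
inactive = torch.randn(z_dim, t_lat, h_lat, w_lat)          # VAE latent of the masked-out content
reactive = torch.randn(z_dim, t_lat, h_lat, w_lat)          # VAE latent of the content to regenerate
latent = torch.cat((inactive, reactive), dim=0)             # 32 channels per latent frame
ref = torch.randn(z_dim, 1, h_lat, w_lat)                   # latent of one reference image
ref = torch.cat((ref, torch.zeros_like(ref)), dim=0)        # zero-padded to 32 channels
latent = torch.cat([ref, latent], dim=1)                    # reference frame prepended in time
assert latent.shape == (2 * z_dim, t_lat + 1, h_lat, w_lat)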
def vace_encode_masks(self, masks, ref_images=None):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.vae_stride[0])
height = 2 * (int(height) // (self.vae_stride[1] * 2))
width = 2 * (int(width) // (self.vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, self.vae_stride[1], width, self.vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
self.vae_stride[1] * self.vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
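A worked example of the mask arithmetic above, assuming vae_stride = (4, 8, 8) and an 81-frame 480x832 mask: each 8x8 spatial patch becomes 64 channels of a latent-resolution mask, which vace_latent then concatenates with the 32-channel frame latents to form the conditioning tensor fed to vace_patch_embedding.

# Standalone check of the reshape/interpolate path (values are assumptions).
import torch
import torch.nn.functional as F

vae_stride = (4, 8, 8)
mask = torch.ones(1, 81, 480, 832)                          # (C, T, H, W) mask as used above
depth, height, width = mask.shape[1:]
new_depth = (depth + 3) // vae_stride[0]                    # 81 frames -> 21 latent frames
height = 2 * (height // (vae_stride[1] * 2))                # 480 -> 60
width = 2 * (width // (vae_stride[2] * 2))                  # 832 -> 104
m = mask[0].view(depth, height, vae_stride[1], width, vae_stride[2])
m = m.permute(2, 4, 0, 1, 3).reshape(vae_stride[1] * vae_stride[2], depth, height, width)
m = F.interpolate(m.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
assert m.shape == (64, 21, 60, 104)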
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
src_video[i] = src_video[i].to(device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
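For orientation, a rough end-to-end driver showing how these helpers are intended to chain together. This is only a sketch: it assumes `pipe` is an already-loaded WanT2V with a Vace checkpoint, that `pipe.device` exists, and that the generate parameters not shown here (sampling steps, VAE tiling, etc.) have usable defaults; none of that is verified against the full signatures.

# Hypothetical driver, not part of the repository.
def run_vace(pipe, prompt, control_video, mask_video, ref_images, num_frames=81, size=(832, 480)):
    # control_video / mask_video are file paths, ref_images a list of PIL images (or None)
    src_video, src_mask, src_ref = pipe.prepare_source(
        [control_video], [mask_video], [ref_images],
        num_frames, image_size=(size[1], size[0]), device=pipe.device)
    return pipe.generate(
        prompt,
        input_frames=src_video,
        input_masks=src_mask,
        input_ref_images=src_ref,
        context_scale=1.0,
        size=size,
        frame_num=num_frames)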
    def generate(self,
                 input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
context_scale=1.0,
                 size=(1280, 720),
                 frame_num=81,
                 shift=5.0,
@ -187,14 +328,6 @@ class WanT2V:
                - W: Frame width from size)
        """
        # preprocess
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
@ -213,6 +346,29 @@
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
        context = [t.to(self.device) for t in context]
        context_null = [t.to(self.device) for t in context_null]
if input_frames != None:
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
else:
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
        noise = [
            torch.randn(
@ -261,10 +417,12 @@ class WanT2V:
        arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
arg_c.update(vace_dict)
arg_null.update(vace_dict)
arg_both.update(vace_dict)
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
        if self.model.enable_teacache:
            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
        if callback != None:
@ -281,7 +439,7 @@ class WanT2V:
                # self.model.to(self.device)
                if joint_pass:
                    noise_pred_cond, noise_pred_uncond = self.model(
                        latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
                    if self._interrupt:
                        return None
                else:
@ -329,7 +487,11 @@ class WanT2V:
            self.model.cpu()
            torch.cuda.empty_cache()
        if self.rank == 0:
            if input_frames == None:
                videos = self.vae.decode(x0, VAE_tile_size)
            else:
                videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)

        del noise, latents

View File

@ -3,21 +3,70 @@ import argparse
import binascii
import os
import os.path as osp
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import imageio
import torch
import decord
import torchvision
from PIL import Image
import numpy as np
from rembg import remove, new_session
__all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def get_video_frame(file_name, frame_no):
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(file_name)
frame = reader.get_batch([frame_no]).squeeze(0)
img = Image.fromarray(frame.numpy().astype(np.uint8))
return img
def resize_lanczos(img, h, w):
    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def remove_background(img, session=None):
if session ==None:
session = new_session()
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
if rm_background:
session = new_session()
output_list =[]
for img in img_list:
width, height = img.size
white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
scale = min(canvas_height / height, canvas_width / width)
new_height = int(height * scale)
new_width = int(width * scale)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background:
resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
img = Image.fromarray(white_canvas)
output_list.append(img)
return output_list
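A minimal usage sketch of the helper above (file names are hypothetical; rembg must be installed for the background-removal path):

# Fit two reference images onto a white 832x480 canvas, stripping their backgrounds first.
from PIL import Image

refs = [Image.open("ref1.png"), Image.open("ref2.png")]     # hypothetical inputs
refs = resize_and_remove_background(refs, 832, 480, rm_background=True)
assert all(r.size == (832, 480) for r in refs)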
def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:

View File

@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
def _pillow_convert(self, image, cvt_type='RGB'):
if image.mode != cvt_type:
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
image = image.convert(cvt_type)
return image
def _load_image(self, img_path):
if img_path is None or img_path == '':
return None
img = Image.open(img_path)
img = self._pillow_convert(img)
return img
def _resize_crop(self, img, oh, ow, normalize=True):
"""
Resize, center crop, convert to tensor, and normalize.
"""
# resize and crop
iw, ih = img.size
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize(
(round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS
)
assert img.width >= ow and img.height >= oh
# center crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
# normalize
if normalize:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
return img
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
return self._resize_crop(img, oh, ow, normalize)
def load_image(self, data_key, **kwargs):
return self.load_image_batch(data_key, **kwargs)
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
img = self._load_image(data_key)
imgs.append(img)
w, h = imgs[0].size
dh, dw = self.downsample[1:]
# compute output size
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
oh = int(h * scale) // dh * dh
ow = int(w * scale) // dw * dw
assert (oh // dh) * (ow // dw) <= seq_len
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
return *imgs, (oh, ow)
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
Raises:
"""
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True
)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
return video
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))
of = min(
(int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / rand_area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(max_area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]
), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
import math
target_fps = self.max_fps
video_duration = frame_timestamps[-1][1]
video_frame_duration = 1 /fps
target_frame_duration = 1 / target_fps
cur_time = 0
target_time = 0
frame_no = 0
frame_ids =[]
for i in range(max_frames):
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
frame_no += add_frames_count
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
if cur_time > video_duration:
break
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
seq_len = self.seq_len
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
# max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
max_area_z = min_area_z # workaround bug
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))
seq_len = max_area_z * ((max_frames- 1) // df +1)
# of = min(
# (len(frame_ids) - 1) // df + 1,
# int(seq_len / rand_area_z)
# )
of = (len(frame_ids) - 1) // df + 1
# deduce target shape of the [latent video]
# target_area_z = min(max_area_z, int(seq_len / of))
target_area_z = max_area_z
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
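A worked example of the frame-selection loop at the top of _get_frameid_bbox_adjust_last, with assumed numbers (30 fps source, 16 fps target, 8 output frames; the early exit on video_duration is omitted because the clip is assumed long enough):

# Standalone re-run of the resampling loop with hypothetical inputs.
import math

fps, target_fps, max_frames = 30.0, 16.0, 8
video_frame_duration, target_frame_duration = 1 / fps, 1 / target_fps
cur_time = target_time = 0.0
frame_no, frame_ids = 0, []
for _ in range(max_frames):
    add = math.ceil((target_time - cur_time) / video_frame_duration)
    frame_no += add
    frame_ids.append(frame_no)
    cur_time += add * video_frame_duration
    target_time += target_frame_duration
print(frame_ids)   # [0, 2, 4, 6, 8, 10, 12, 14]: roughly every other source frame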
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
# # frame_timestamps = frame_timestamps[ :max_frames]
# if trim_video > 0:
# frame_timestamps = frame_timestamps[ :trim_video]
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
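A hypothetical call mirroring how WanT2V wires this processor up (downsample (4, 16, 16) assumes vae_stride (4, 8, 8) times patch size (1, 2, 2); file names are placeholders and decord must be installed):

# Illustrative only; paths and parameter values are assumptions.
vp = VaceVideoProcessor(downsample=(4, 16, 16),
                        min_area=480 * 832, max_area=480 * 832,
                        min_fps=16, max_fps=16,
                        zero_start=True, seq_len=32760, keep_last=True)
video, mask, frame_ids, (oh, ow), fps = vp.load_video_pair(
    "control.mp4", "mask.mp4", max_frames=81, trim_video=81)
# video and mask come back as (C, T, H, W) tensors normalized to [-1, 1]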
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images