Added Vace ControlNet support

This commit is contained in:
DeepBeepMeep 2025-04-05 02:06:17 +02:00
parent 8f1e6149a8
commit cf495a4aaf
7 changed files with 929 additions and 181 deletions

View File

@ -14,7 +14,7 @@ import gradio as gr
import random
import json
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes, get_supported_attention_modes
import torch
@ -55,6 +55,11 @@ def format_time(seconds):
def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None:
return None
if isinstance(pil_image, str):
from wan.utils.utils import get_video_frame
pil_image = get_video_frame(pil_image, 0)
buffer = io.BytesIO()
try:
img_to_save = pil_image
@ -93,10 +98,11 @@ def process_prompt_and_add_tasks(
loras_choices,
loras_mult_choices,
image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
image_source1,
image_source2,
image_source3,
max_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
RIFLEx_setting,
@ -127,9 +133,9 @@ def process_prompt_and_add_tasks(
return
file_model_needed = model_needed(image2video)
width, height = resolution.split("x")
width, height = int(width), int(height)
if image2video:
width, height = resolution.split("x")
width, height = int(width), int(height)
if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@ -143,74 +149,94 @@ def process_prompt_and_add_tasks(
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return
if image2video:
if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0:
if not image2video:
if "Vace" in file_model_needed and "1.3B" in file_model_needed :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = " and ".join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"This video resolution is not supported by the Vace model. Only the {res} resolutions are allowed.")
return
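# t2v / Vace path: each letter in image_prompt_type enables one control input ('I' reference images, 'V' control video, 'M' video mask); unused sources are cleared and the remaining ones replicated per prompt below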
if not "I" in image_prompt_type:
image_source1 = None
if not "V" in image_prompt_type:
image_source2 = None
if not "M" in image_prompt_type:
image_source3 = None
if isinstance(image_source1, list):
image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
from wan.utils.utils import resize_and_remove_background
image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1)
image_source1 = [ image_source1 ] * len(prompts)
image_source2 = [ image_source2 ] * len(prompts)
image_source3 = [ image_source3 ] * len(prompts)
else:
if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0:
return
if image_prompt_type == 0:
image_to_end = None
if isinstance(image_to_continue, list):
image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ]
image_source2 = None
if isinstance(image_source1, list):
image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
else:
image_to_continue = [convert_image(image_to_continue)]
if image_to_end != None:
if isinstance(image_to_end , list):
image_to_end = [ convert_image(tup[0]) for tup in image_to_end ]
image_source1 = [convert_image(image_source1)]
if image_source2 != None:
if isinstance(image_source2 , list):
image_source2 = [ convert_image(tup[0]) for tup in image_source2 ]
else:
image_to_end = [convert_image(image_to_end) ]
if len(image_to_continue) != len(image_to_end):
image_source2 = [convert_image(image_source2) ]
if len(image_source1) != len(image_source2):
gr.Info("The number of start and end images should be the same ")
return
if multi_images_gen_type == 0:
new_prompts = []
new_image_to_continue = []
new_image_to_end = []
for i in range(len(prompts) * len(image_to_continue) ):
new_image_source1 = []
new_image_source2 = []
for i in range(len(prompts) * len(image_source1) ):
new_prompts.append( prompts[ i % len(prompts)] )
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i // len(prompts)] )
new_image_source1.append(image_source1[i // len(prompts)] )
if image_source2 != None:
new_image_source2.append(image_source2[i // len(prompts)] )
prompts = new_prompts
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
image_source1 = new_image_source1
if image_source2 != None:
image_source2 = new_image_source2
else:
if len(prompts) >= len(image_to_continue):
if len(prompts) % len(image_to_continue) !=0:
if len(prompts) >= len(image_source1):
if len(prompts) % len(image_source1) !=0:
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
rep = len(prompts) // len(image_to_continue)
new_image_to_continue = []
new_image_to_end = []
rep = len(prompts) // len(image_source1)
new_image_source1 = []
new_image_source2 = []
for i, _ in enumerate(prompts):
new_image_to_continue.append(image_to_continue[i//rep] )
if image_to_end != None:
new_image_to_end.append(image_to_end[i//rep] )
image_to_continue = new_image_to_continue
if image_to_end != None:
image_to_end = new_image_to_end
new_image_source1.append(image_source1[i//rep] )
if image_source2 != None:
new_image_source2.append(image_source2[i//rep] )
image_source1 = new_image_source1
if image_source2 != None:
image_source2 = new_image_source2
else:
if len(image_to_continue) % len(prompts) !=0:
if len(image_source1) % len(prompts) !=0:
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
rep = len(image_to_continue) // len(prompts)
rep = len(image_source1) // len(prompts)
new_prompts = []
for i, _ in enumerate(image_to_continue):
for i, _ in enumerate(image_source1):
new_prompts.append( prompts[ i//rep] )
prompts = new_prompts
# elif video_to_continue != None and len(video_to_continue) >0 :
# input_image_or_video_path = video_to_continue
# # pipeline.num_input_frames = max_frames
# # pipeline.max_frames = max_frames
# else:
# return
# else:
# input_image_or_video_path = None
if image_to_continue == None:
image_to_continue = [None] * len(prompts)
if image_to_end == None:
image_to_end = [None] * len(prompts)
if image_source1 == None:
image_source1 = [None] * len(prompts)
if image_source2 == None:
image_source2 = [None] * len(prompts)
if image_source3 == None:
image_source3 = [None] * len(prompts)
for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) :
kwargs = {
"prompt" : single_prompt,
"negative_prompt" : negative_prompt,
@ -228,10 +254,11 @@ def process_prompt_and_add_tasks(
"loras_choices" : loras_choices,
"loras_mult_choices" : loras_mult_choices,
"image_prompt_type" : image_prompt_type,
"image_to_continue": image_start,
"image_to_end" : image_end,
"video_to_continue" : video_to_continue ,
"image_source1": image_source1,
"image_source2" : image_source2,
"image_source3" : image_source3 ,
"max_frames" : max_frames,
"remove_background_image_ref" : remove_background_image_ref,
"temporal_upsampling" : temporal_upsampling,
"spatial_upsampling" : spatial_upsampling,
"RIFLEx_setting" : RIFLEx_setting,
@ -262,8 +289,9 @@ def add_video_task(**kwargs):
queue = gen["queue"]
task_id += 1
current_task_id = task_id
start_image_data = kwargs["image_to_continue"]
end_image_data = kwargs["image_to_end"]
start_image_data = kwargs["image_source1"]
start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data
end_image_data = kwargs["image_source2"]
queue.append({
"id": current_task_id,
@ -275,7 +303,7 @@ def add_video_task(**kwargs):
"prompt": kwargs["prompt"],
"start_image_data": start_image_data,
"end_image_data": end_image_data,
"start_image_data_base64": pil_to_base64_uri(start_image_data, format="jpeg", quality=70),
"start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data],
"end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
})
return update_queue_data(queue)
@ -342,6 +370,7 @@ def get_queue_table(queue):
full_prompt = item['prompt'].replace('"', '"')
prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
start_img_uri =item.get('start_image_data_base64')
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
end_img_uri = item.get('end_image_data_base64')
thumbnail_size = "50px"
num_steps = item.get('steps')
@ -694,6 +723,9 @@ attention_modes_installed = get_attention_modes()
attention_modes_supported = get_supported_attention_modes()
args = _parse_args()
args.flow_reverse = True
processing_device = args.gpu
if len(processing_device) == 0:
processing_device ="cuda"
# torch.backends.cuda.matmul.allow_fp16_accumulation = True
lock_ui_attention = False
lock_ui_transformer = False
@ -706,7 +738,7 @@ quantizeTransformer = args.quantize_transformer
check_loras = args.check_loras ==1
advanced = args.advanced
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
@ -750,7 +782,7 @@ def get_default_settings(filename, i2v):
"prompts": get_default_prompt(i2v),
"resolution": "832x480",
"video_length": 81,
"image_prompt_type" : 0,
"image_prompt_type" : 0 if i2v else "",
"num_inference_steps": 30,
"seed": -1,
"repeat_generation": 1,
@ -1149,6 +1181,9 @@ def get_model_name(model_filename):
if "Fun" in model_filename:
model_name = "Fun InP image2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
elif "Vace" in model_filename:
model_name = "Vace ControlNet text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
elif "image" in model_filename:
model_name = "Wan2.1 image2video"
model_name += " 720p" if "720p" in model_filename else " 480p"
@ -1353,22 +1388,22 @@ def refresh_gallery(state, msg):
end_img_md = ""
prompt = task["prompt"]
if task.get('image2video'):
start_img_uri = task.get('start_image_data_base64')
end_img_uri = task.get('end_image_data_base64')
thumbnail_size = "100px"
if start_img_uri:
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
if end_img_uri:
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
start_img_uri = task.get('start_image_data_base64')
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
end_img_uri = task.get('end_image_data_base64')
thumbnail_size = "100px"
if start_img_uri:
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
if end_img_uri:
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
label = f"Prompt of Video being Generated"
html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
if start_img_md != "":
html += "<TD>" + start_img_md + "</TD>"
if end_img_md != "":
html += "<TD>" + end_img_md + "</TD>"
if end_img_md != "":
html += "<TD>" + end_img_md + "</TD>"
html += "</TR></TABLE>"
html_output = gr.HTML(html, visible= True)
@ -1419,24 +1454,26 @@ def expand_slist(slist, num_inference_steps ):
new_slist.append(slist[ int(pos)])
pos += inc
return new_slist
def convert_image(image):
from PIL import ExifTags
image = image.convert('RGB')
for orientation in ExifTags.TAGS.keys():
if ExifTags.TAGS[orientation]=='Orientation':
break
exif = image.getexif()
if not orientation in exif:
return image
if exif[orientation] == 3:
image=image.rotate(180, expand=True)
elif exif[orientation] == 6:
image=image.rotate(270, expand=True)
elif exif[orientation] == 8:
image=image.rotate(90, expand=True)
return image
from PIL import ExifTags, ImageOps
from typing import cast
return cast(Image, ImageOps.exif_transpose(image))
# image = image.convert('RGB')
# for orientation in ExifTags.TAGS.keys():
# if ExifTags.TAGS[orientation]=='Orientation':
# break
# exif = image.getexif()
# return image
# if not orientation in exif:
# if exif[orientation] == 3:
# image=image.rotate(180, expand=True)
# elif exif[orientation] == 6:
# image=image.rotate(270, expand=True)
# elif exif[orientation] == 8:
# image=image.rotate(90, expand=True)
# return image
def generate_video(
task_id,
@ -1457,10 +1494,11 @@ def generate_video(
loras_choices,
loras_mult_choices,
image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
image_source1,
image_source2,
image_source3,
max_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
RIFLEx_setting,
@ -1507,7 +1545,6 @@ def generate_video(
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
return
if not image2video:
width, height = resolution.split("x")
@ -1586,7 +1623,7 @@ def generate_video(
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
joint_pass = boost ==1 #and profile != 1 and profile != 3
# TeaCache
@ -1615,6 +1652,17 @@ def generate_video(
else:
raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
resolution_reformated = str(height) + "*" + str(width)
src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2],
[image_source3],
[image_source1],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
trim_video=max_frames)
else:
src_video, src_mask, src_ref_images = None, None, None
import random
if seed == None or seed <0:
seed = random.randint(0, 999999999)
@ -1673,8 +1721,8 @@ def generate_video(
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue,
image_to_end if image_to_end != None else None,
image_source1,
image_source2 if image_source2 != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
@ -1697,6 +1745,9 @@ def generate_video(
else:
samples = wan_model.generate(
prompt,
input_frames = src_video,
input_ref_images= src_ref_images,
input_masks = src_mask,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
@ -1745,7 +1796,7 @@ def generate_video(
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
else:
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
tb = traceback.format_exc().split('\n')[:-2]
tb = traceback.format_exc().split('\n')[:-1]
print('\n'.join(tb))
raise gr.Error(new_error, print_exception= False)
@ -1799,7 +1850,7 @@ def generate_video(
if exp > 0:
from rife.inference import temporal_interpolation
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda")
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
fps = fps * 2**exp
if len(spatial_upsampling) > 0:
@ -1831,8 +1882,7 @@ def generate_video(
normalize=True,
value_range=(-1, 1))
configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
metadata_choice = server_config.get("metadata_choice","metadata")
@ -2294,7 +2344,7 @@ def switch_advanced(state, new_advanced, lset_name):
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
loras = state["loras"]
@ -2330,18 +2380,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
ui_settings["image_prompt_type"] = image_prompt_type
else:
if "Vace" in transformer_filename_t2v or not image_metadata:
ui_settings["image_prompt_type"] = image_prompt_type
ui_settings["max_frames"] = max_frames
ui_settings["remove_background_image_ref"] = remove_background_image_ref
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
return ui_settings
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
if state.get("validate_success",0) != 1:
return
image2video = state["image2video"]
ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
defaults_filename = get_settings_file_name(image2video)
@ -2379,6 +2433,25 @@ def download_loras():
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
return
def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio):
if args.multiple_images:
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
else:
return gr.Image(visible = (image_prompt_type_radio == 1) )
def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio):
vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"]
return gr.Column(visible= vace_model), gr.Radio(value= image_prompt_type_radio), gr.Gallery(visible = "I" in image_prompt_type_radio), gr.Video(visible= "V" in image_prompt_type_radio),gr.Video(visible= "M" in image_prompt_type_radio ), gr.Text(visible= "V" in image_prompt_type_radio) , gr.Checkbox(visible= "I" in image_prompt_type_radio)
def check_refresh_input_type(state):
if not state["image2video"]:
model_file_name = state["image_input_type_model"]
model_file_needed= model_needed(False)
if model_file_name != model_file_needed:
state["image_input_type_model"] = model_file_needed
return gr.Text(value= str(time.time()))
return gr.Text()
def generate_video_tab(image2video=False):
filename = transformer_filename_i2v if image2video else transformer_filename_t2v
ui_defaults= get_default_settings(filename, image2video)
@ -2387,6 +2460,7 @@ def generate_video_tab(image2video=False):
state_dict["advanced"] = advanced
state_dict["loras_model"] = filename
state_dict["image_input_type_model"] = filename
state_dict["image2video"] = image2video
gen = dict()
gen["queue"] = []
@ -2461,31 +2535,51 @@ def generate_video_tab(image2video=False):
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
image_prompt_type= ui_defaults.get("image_prompt_type",0)
image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
if args.multiple_images:
image_to_continue = gr.Gallery(
label="Images as starting points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video)
else:
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video)
state = gr.State(state_dict)
vace_model = "Vace" in filename and not image2video
trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
with gr.Column(visible= image2video or vace_model) as image_prompt_column:
if image2video:
image_source3 = gr.Video(label= "Placeholder", visible= image2video and False)
if args.multiple_images:
image_to_end = gr.Gallery(
label="Images as ending points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
else:
image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
image_prompt_type= ui_defaults.get("image_prompt_type",0)
image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3)
def switch_image_prompt_type_radio(image_prompt_type_radio):
if args.multiple_images:
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
if args.multiple_images:
image_source1 = gr.Gallery(
label="Images as starting points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True)
else:
image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil")
if args.multiple_images:
image_source2 = gr.Gallery(
label="Images as ending points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
else:
image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2])
max_frames = gr.Slider(1, 100,step=1, visible = False)
remove_background_image_ref = gr.Text(visible = False)
else:
return gr.Image(visible = (image_prompt_type_radio == 1) )
image_prompt_type= ui_defaults.get("image_prompt_type","I")
image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model)
image_source1 = gr.Gallery(
label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type )
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type )
with gr.Row():
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 )
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in image_prompt_type, scale =1 )
image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type )
gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref])
advanced_prompt = advanced
@ -2518,7 +2612,6 @@ def generate_video_tab(image2video=False):
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
state = gr.State(state_dict)
with gr.Row():
if image2video:
resolution = gr.Dropdown(
@ -2555,8 +2648,6 @@ def generate_video_tab(image2video=False):
video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
with gr.Column():
num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
with gr.Row():
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=image2video and False) #########
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
with gr.Row(visible=advanced) as advanced_row:
with gr.Column():
@ -2605,7 +2696,7 @@ def generate_video_tab(image2video=False):
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
with gr.Row():
gr.Markdown("<B>Upsampling</B>")
gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
with gr.Row():
temporal_upsampling_choice = gr.Dropdown(
choices=[
@ -2687,9 +2778,10 @@ def generate_video_tab(image2video=False):
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
gen_status = gr.Text(label="Status", interactive= False)
full_sync = gr.Text(label="Status", interactive= False, visible= False)
light_sync = gr.Text(label="Status", interactive= False, visible= False)
gen_status = gr.Text(interactive= False)
full_sync = gr.Text(interactive= False, visible= False)
light_sync = gr.Text(interactive= False, visible= False)
gen_progress_html = gr.HTML(
label="Status",
value="Idle",
@ -2709,8 +2801,8 @@ def generate_video_tab(image2video=False):
abort_btn = gr.Button("Abort")
queue_df = gr.DataFrame(
headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
interactive=False,
col_count=(9, "fixed"),
@ -2792,7 +2884,7 @@ def generate_video_tab(image2video=False):
show_progress="hidden"
)
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@ -2808,21 +2900,30 @@ def generate_video_tab(image2video=False):
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
output.select(select_video, state, None )
gen_status.change(refresh_gallery,
inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
full_sync.change(refresh_gallery,
full_sync.change(fn= check_refresh_input_type,
inputs= [state],
outputs= [trigger_refresh_input_type]
).then(fn=refresh_gallery,
inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
).then( fn=wait_tasks_done,
).then(fn=wait_tasks_done,
inputs= [state],
outputs =[gen_status],
).then(finalize_generation,
inputs= [state],
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
)
light_sync.change(refresh_gallery,
light_sync.change(fn= check_refresh_input_type,
inputs= [state],
outputs= [trigger_refresh_input_type]
).then(fn=refresh_gallery,
inputs = [state, gen_status],
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
)
@ -2848,10 +2949,11 @@ def generate_video_tab(image2video=False):
loras_choices,
loras_mult_choices,
image_prompt_type_radio,
image_to_continue,
image_to_end,
video_to_continue,
image_source1,
image_source2,
image_source3,
max_frames,
remove_background_image_ref,
temporal_upsampling_choice,
spatial_upsampling_choice,
RIFLEx_setting,
@ -2902,7 +3004,7 @@ def generate_video_tab(image2video=False):
)
return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state):
def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state):
with gr.Row():
with gr.Row(scale =2):
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
@ -2928,6 +3030,7 @@ def generate_configuration_tab():
("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
],
value= index,
label="Transformer model for Text to Video",
@ -3108,16 +3211,17 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
t2v_light_sync = gr.Text()
i2v_full_sync = gr.Text()
t2v_full_sync = gr.Text()
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
if last_tab_was_image2video == None or last_tab_was_image2video:
gen = i2v_state["gen"]
t2v_state["gen"] = gen
else:
gen = t2v_state["gen"]
i2v_state["gen"] = gen
if new_t2v or new_i2v:
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
if last_tab_was_image2video == None or last_tab_was_image2video:
gen = i2v_state["gen"]
t2v_state["gen"] = gen
else:
gen = t2v_state["gen"]
i2v_state["gen"] = gen
if last_tab_was_image2video != None and new_t2v != new_i2v:
gen_location = gen.get("location", None)
if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
@ -3131,7 +3235,6 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
else:
t2v_light_sync = gr.Text(str(time.time()))
global_state["last_tab_was_image2video"] = new_i2v
if(server_config.get("reload_model",2) == 1):
@ -3433,7 +3536,7 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@ -3454,7 +3557,7 @@ def create_demo():
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
with gr.Tab("Configuration"):
generate_configuration_tab()
with gr.Tab("About"):

View File

@ -11,11 +11,15 @@ easydict
ftfy
dashscope
imageio-ffmpeg
# flash_attn
# flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.3.4
peft==0.14.0
mutagen
mutagen
decord
onnxruntime-gpu
rembg[gpu]==2.0.65
# rembg==2.0.65

View File

@ -40,3 +40,17 @@ SUPPORTED_SIZES = {
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
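# Vace-specific resolution tables; keys use the same 'height*width' strings that the UI builds (resolution_reformated), and the 1.3B preview only supports the two 480p layouts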
VACE_SIZE_CONFIGS = {
'480*832': (480, 832),
'832*480': (832, 480),
}
VACE_MAX_AREA_CONFIGS = {
'480*832': 480 * 832,
'832*480': 832 * 480,
}
VACE_SUPPORTED_SIZES = {
'vace-1.3B': ('480*832', '832*480'),
}

View File

@ -377,6 +377,7 @@ class WanI2VCrossAttention(WanSelfAttention):
return x
WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
@ -393,7 +394,9 @@ class WanAttentionBlock(nn.Module):
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
eps=1e-6,
block_id=None
):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
@ -422,6 +425,7 @@ class WanAttentionBlock(nn.Module):
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.block_id = block_id
def forward(
self,
@ -432,6 +436,8 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_lens,
hints= None,
context_scale=1.0,
):
r"""
Args:
@ -480,10 +486,49 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[5])
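# Vace: when this block is mapped to a Vace layer, add the matching precomputed hint (the skip output of the corresponding VaceWanAttentionBlock), optionally scaled by context_scale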
if self.block_id is not None and hints != None:
if context_scale == 1:
x.add_(hints[self.block_id])
else:
x.add_(hints[self.block_id], alpha =context_scale)
return x
return x
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
# behold dbm magic !
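# block 0 seeds the control stream by projecting the control latents and adding the current hidden states;
# later blocks pop the running state from the list, process it like a regular attention block, then append
# both its after_proj skip (consumed as a hint by the main blocks) and the updated state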
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = c
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
return all_c
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
@ -544,6 +589,8 @@ class WanModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self,
vace_layers=None,
vace_in_dim=None,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
@ -628,12 +675,13 @@ class WanModel(ModelMixin, ConfigMixin):
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
])
if vace_layers == None:
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps)
@ -646,6 +694,33 @@ class WanModel(ModelMixin, ConfigMixin):
# initialize weights
self.init_weights()
if vace_layers != None:
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# blocks
self.blocks = nn.ModuleList([
WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps,
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps, block_id=i)
for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
@ -688,6 +763,36 @@ class WanModel(ModelMixin, ConfigMixin):
self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold
def forward_vace(
self,
x,
vace_context,
seq_len,
context,
e,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
if (len(c) == 1 and seq_len == c[0].size(1)):
c = c[0]
else:
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for block in self.vace_blocks:
c = block(c, context= context, e= e, **new_kwargs)
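# the list now holds one skip projection per Vace layer plus the final running state; only the skips are returned as hints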
hints = c[:-1]
return hints
def forward(
self,
@ -695,6 +800,8 @@ class WanModel(ModelMixin, ConfigMixin):
t,
context,
seq_len,
vace_context = None,
vace_context_scale=1.0,
clip_fea=None,
y=None,
freqs = None,
@ -829,13 +936,23 @@ class WanModel(ModelMixin, ConfigMixin):
self.previous_residual_cond = None
ori_hidden_states = x_list[0].clone()
# arguments
kwargs = dict(
# e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
# context=context,
context_lens=context_lens)
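# with a Vace context, precompute the per-layer hints for each parallel pass (cond / uncond) before running the main blocks; they are injected by blocks whose block_id maps to a Vace layer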
if vace_context == None:
hints_list = [None ] *len(x_list)
else:
hints_list = []
for x, context in zip(x_list, context_list) :
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
del x, context
kwargs['context_scale'] = vace_context_scale
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
if callback != None:
@ -852,9 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
else:
for i, (x, context) in enumerate(zip(x_list, context_list)):
x_list[i] = block(x, context = context, e= e0, **kwargs)
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
del x
del context, hints
if self.enable_teacache:
if joint_pass:

View File

@ -13,7 +13,9 @@ import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -22,6 +24,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
def optimized_scale(positive_flat, negative_flat):
@ -105,8 +108,6 @@ class WanT2V:
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
self.model.eval().requires_grad_(False)
if use_usp:
@ -132,8 +133,148 @@ class WanT2V:
self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
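# VAE-encode the control frames; with a mask, frames are split into an 'inactive' stream (masked-out pixels) and a 'reactive' stream (masked-in pixels), each encoded and stacked along the channel axis, and reference-image latents are prepended along the time axis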
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
else:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
else:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
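# downsample pixel-space masks to the latent grid: spatial neighbourhoods of vae_stride[1] x vae_stride[2] pixels are folded into the channel dimension, the temporal axis is nearest-interpolated to the latent depth, and zero frames are prepended for the reference images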
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.vae_stride[0])
height = 2 * (int(height) // (self.vae_stride[1] * 2))
width = 2 * (int(width) // (self.vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, self.vae_stride[1], width, self.vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
self.vae_stride[1] * self.vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
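# normalize the control inputs: load video/mask pairs (or a video alone), pad them up to num_frames (video with blank frames, mask with ones), and letterbox every reference image onto a white canvas at the working resolution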
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
src_video[i] = src_video[i].to(device)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
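# drop the reference-image latents that were prepended along the time axis during encoding, then VAE-decode the remaining frames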
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
def generate(self,
input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
context_scale=1.0,
size=(1280, 720),
frame_num=81,
shift=5.0,
@ -187,14 +328,6 @@ class WanT2V:
- W: Frame width from size)
"""
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
@ -213,6 +346,29 @@ class WanT2V:
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
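# Vace path: encode control frames, masks and reference images into the 'vace context'; the latent target shape follows these control latents, with the channel count halved because inactive and reactive latents were concatenated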
if input_frames != None:
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
else:
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
noise = [
torch.randn(
@ -261,10 +417,12 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
arg_c.update(vace_dict)
arg_null.update(vace_dict)
arg_both.update(vace_dict)
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
@ -281,7 +439,7 @@ class WanT2V:
# self.model.to(self.device)
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
if self._interrupt:
return None
else:
@ -329,7 +487,11 @@ class WanT2V:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0, VAE_tile_size)
if input_frames == None:
videos = self.vae.decode(x0, VAE_tile_size)
else:
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
del noise, latents

View File

@ -3,21 +3,70 @@ import argparse
import binascii
import os
import os.path as osp
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import imageio
import torch
import decord
import torchvision
from PIL import Image
import numpy as np
from rembg import remove, new_session
__all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def get_video_frame(file_name, frame_no):
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(file_name)
frame = reader.get_batch([frame_no]).squeeze(0)
img = Image.fromarray(frame.numpy().astype(np.uint8))
return img
def resize_lanczos(img, h, w):
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def remove_background(img, session=None):
if session ==None:
session = new_session()
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
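# scale each reference image to fit the canvas while keeping its aspect ratio, optionally strip its background with rembg, then paste it centered on a white canvas of the target size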
if rm_background:
session = new_session()
output_list =[]
for img in img_list:
width, height = img.size
white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
scale = min(canvas_height / height, canvas_width / width)
new_height = int(height * scale)
new_width = int(width * scale)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background:
resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
img = Image.fromarray(white_canvas)
output_list.append(img)
return output_list
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:

View File

@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
def _pillow_convert(self, image, cvt_type='RGB'):
if image.mode != cvt_type:
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
image = image.convert(cvt_type)
return image
def _load_image(self, img_path):
if img_path is None or img_path == '':
return None
img = Image.open(img_path)
img = self._pillow_convert(img)
return img
def _resize_crop(self, img, oh, ow, normalize=True):
"""
Resize, center crop, convert to tensor, and normalize.
"""
# resize and crop
iw, ih = img.size
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize(
(round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS
)
assert img.width >= ow and img.height >= oh
# center crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
# normalize
if normalize:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
return img
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
return self._resize_crop(img, oh, ow, normalize)
def load_image(self, data_key, **kwargs):
return self.load_image_batch(data_key, **kwargs)
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
img = self._load_image(data_key)
imgs.append(img)
w, h = imgs[0].size
dh, dw = self.downsample[1:]
# compute output size
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
oh = int(h * scale) // dh * dh
ow = int(w * scale) // dw * dw
assert (oh // dh) * (ow // dw) <= seq_len
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
return *imgs, (oh, ow)
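# Hypothetical usage sketch (downsample/seq_len values are assumed, not from this commit):
#
#   proc = VaceImageProcessor(downsample=(4, 8, 8), seq_len=32760)
#   img, (oh, ow) = proc.load_image("ref.png")
#   # img: (3, 1, oh, ow) tensor normalized to [-1, 1]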
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
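        # seq_len is the token budget of the latent video; the assert above guarantees
        # that at least one latent frame at min_area fits within that budget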
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
        """
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True
)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
return video
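    # Illustrative shapes (assumed input, not from this commit): a decord batch of shape
    # (T=81, H=720, W=1280, C=3) cropped/resized to (oh, ow) = (480, 832) is returned as
    # a float tensor of shape (3, 81, 480, 832) with values in [-1, 1].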
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))
of = min(
(int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / rand_area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(max_area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]
), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
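    # Illustrative numbers (assumed, not from this commit): with seq_len=32760, df=4,
    # a 5 s clip resampled to 16 fps and a sampled latent area of ~1560 tokens/frame,
    # of = min((80 - 1)//4 + 1, 32760//1560) = min(20, 21) = 20 latent frames,
    # i.e. (20 - 1)*4 + 1 = 77 decoded frames at the deduced (oh, ow).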
    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames=0):
        import math
        target_fps = self.max_fps
        video_duration = frame_timestamps[-1][1]
        video_frame_duration = 1 / fps
        target_frame_duration = 1 / target_fps

        # walk through the source video and keep, for each target timestamp, the next
        # source frame, resampling the source fps down to target_fps
        cur_time = 0
        target_time = 0
        frame_no = 0
        frame_ids = []
        for i in range(max_frames):
            add_frames_count = math.ceil((target_time - cur_time) / video_frame_duration)
            frame_no += add_frames_count
            frame_ids.append(frame_no)
            cur_time += add_frames_count * video_frame_duration
            target_time += target_frame_duration
            if cur_time > video_duration:
                break
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
seq_len = self.seq_len
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
# max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
max_area_z = min_area_z # workaround bug
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))
        seq_len = max_area_z * ((max_frames - 1) // df + 1)
# of = min(
# (len(frame_ids) - 1) // df + 1,
# int(seq_len / rand_area_z)
# )
of = (len(frame_ids) - 1) // df + 1
# deduce target shape of the [latent video]
# target_area_z = min(max_area_z, int(seq_len / of))
target_area_z = max_area_z
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
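    # Illustrative (assumed fps values): for a 30 fps source and max_fps=16 the loop
    # above picks source frames 0, 2, 4, 6, ... (roughly every 30/16 ≈ 1.9th frame)
    # until max_frames ids are collected or the source duration is exhausted.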
    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames=0):
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
        else:
            # _get_frameid_bbox_default does not accept max_frames, so it is not forwarded here
            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames=0, trim_video=0, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
# # frame_timestamps = frame_timestamps[ :max_frames]
# if trim_video > 0:
# frame_timestamps = frame_timestamps[ :trim_video]
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
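# Hypothetical usage sketch (constructor values are assumed, not from this commit):
#
#   vproc = VaceVideoProcessor(downsample=(4, 8, 8), min_area=480*832, max_area=480*832,
#                              min_fps=16, max_fps=16, zero_start=True,
#                              seq_len=32760, keep_last=True)
#   video, mask, frame_ids, (oh, ow), fps = vproc.load_video_pair(
#       "control.mp4", "mask.mp4", max_frames=81)
#   # video, mask: (C, T, oh, ow) tensors in [-1, 1], resampled to max_fps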
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images
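# Hypothetical usage sketch (shapes are assumed, not from this commit): create neutral
# control/mask tensors for one sample and letterbox a (3, 1, h, w) reference tensor,
# already normalized to [-1, 1], onto a white canvas.
#
#   src_video, src_mask, src_refs = prepare_source(
#       [None], [None], [[ref_tensor]], num_frames=81,
#       image_size=(480, 832), device=torch.device("cuda"))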