From cf495a4aaf6b1ab8f3bb71c494effe704017b345 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 5 Apr 2025 02:06:17 +0200 Subject: [PATCH] Added Vac Contronet support --- gradio_server.py | 403 +++++++++++++++++++++------------ requirements.txt | 8 +- wan/configs/__init__.py | 14 ++ wan/modules/model.py | 144 ++++++++++-- wan/text2video.py | 194 ++++++++++++++-- wan/utils/utils.py | 49 ++++ wan/utils/vace_preprocessor.py | 298 ++++++++++++++++++++++++ 7 files changed, 929 insertions(+), 181 deletions(-) create mode 100644 wan/utils/vace_preprocessor.py diff --git a/gradio_server.py b/gradio_server.py index e85f67b..bcc0164 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -14,7 +14,7 @@ import gradio as gr import random import json import wan -from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES +from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS from wan.utils.utils import cache_video from wan.modules.attention import get_attention_modes, get_supported_attention_modes import torch @@ -55,6 +55,11 @@ def format_time(seconds): def pil_to_base64_uri(pil_image, format="png", quality=75): if pil_image is None: return None + + if isinstance(pil_image, str): + from wan.utils.utils import get_video_frame + pil_image = get_video_frame(pil_image, 0) + buffer = io.BytesIO() try: img_to_save = pil_image @@ -93,10 +98,11 @@ def process_prompt_and_add_tasks( loras_choices, loras_mult_choices, image_prompt_type, - image_to_continue, - image_to_end, - video_to_continue, + image_source1, + image_source2, + image_source3, max_frames, + remove_background_image_ref, temporal_upsampling, spatial_upsampling, RIFLEx_setting, @@ -127,9 +133,9 @@ def process_prompt_and_add_tasks( return file_model_needed = model_needed(image2video) + width, height = resolution.split("x") + width, height = int(width), int(height) if image2video: - width, height = resolution.split("x") - width, height = int(width), int(height) if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480: gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") @@ -143,74 +149,94 @@ def process_prompt_and_add_tasks( gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") return - if image2video: - if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0: + if not image2video: + if "Vace" in file_model_needed and "1.3B" in file_model_needed : + resolution_reformated = str(height) + "*" + str(width) + if not resolution_reformated in VACE_SIZE_CONFIGS: + res = VACE_SIZE_CONFIGS.keys().join(" and ") + gr.Info(f"Video Resolution for Vace model is not supported. 
Only {res} resolutions are allowed.") + return + + if not "I" in image_prompt_type: + image_source1 = None + if not "V" in image_prompt_type: + image_source2 = None + if not "M" in image_prompt_type: + image_source3 = None + + if isinstance(image_source1, list): + image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] + + from wan.utils.utils import resize_and_remove_background + image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1) + + image_source1 = [ image_source1 ] * len(prompts) + image_source2 = [ image_source2 ] * len(prompts) + image_source3 = [ image_source3 ] * len(prompts) + + else: + if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0: return if image_prompt_type == 0: - image_to_end = None - if isinstance(image_to_continue, list): - image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ] + image_source2 = None + if isinstance(image_source1, list): + image_source1 = [ convert_image(tup[0]) for tup in image_source1 ] else: - image_to_continue = [convert_image(image_to_continue)] - if image_to_end != None: - if isinstance(image_to_end , list): - image_to_end = [ convert_image(tup[0]) for tup in image_to_end ] + image_source1 = [convert_image(image_source1)] + if image_source2 != None: + if isinstance(image_source2 , list): + image_source2 = [ convert_image(tup[0]) for tup in image_source2 ] else: - image_to_end = [convert_image(image_to_end) ] - if len(image_to_continue) != len(image_to_end): + image_source2 = [convert_image(image_source2) ] + if len(image_source1) != len(image_source2): gr.Info("The number of start and end images should be the same ") return if multi_images_gen_type == 0: new_prompts = [] - new_image_to_continue = [] - new_image_to_end = [] - for i in range(len(prompts) * len(image_to_continue) ): + new_image_source1 = [] + new_image_source2 = [] + for i in range(len(prompts) * len(image_source1) ): new_prompts.append( prompts[ i % len(prompts)] ) - new_image_to_continue.append(image_to_continue[i // len(prompts)] ) - if image_to_end != None: - new_image_to_end.append(image_to_end[i // len(prompts)] ) + new_image_source1.append(image_source1[i // len(prompts)] ) + if image_source2 != None: + new_image_source2.append(image_source2[i // len(prompts)] ) prompts = new_prompts - image_to_continue = new_image_to_continue - if image_to_end != None: - image_to_end = new_image_to_end + image_source1 = new_image_source1 + if image_source2 != None: + image_source2 = new_image_source2 else: - if len(prompts) >= len(image_to_continue): - if len(prompts) % len(image_to_continue) !=0: + if len(prompts) >= len(image_source1): + if len(prompts) % len(image_source1) !=0: raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_to_continue) - new_image_to_continue = [] - new_image_to_end = [] + rep = len(prompts) // len(image_source1) + new_image_source1 = [] + new_image_source2 = [] for i, _ in enumerate(prompts): - new_image_to_continue.append(image_to_continue[i//rep] ) - if image_to_end != None: - new_image_to_end.append(image_to_end[i//rep] ) - image_to_continue = new_image_to_continue - if image_to_end != None: - image_to_end = new_image_to_end + new_image_source1.append(image_source1[i//rep] ) + if image_source2 != None: + new_image_source2.append(image_source2[i//rep] ) + image_source1 = new_image_source1 + if image_source2 != None: + 
image_source2 = new_image_source2 else: - if len(image_to_continue) % len(prompts) !=0: + if len(image_source1) % len(prompts) !=0: raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_to_continue) // len(prompts) + rep = len(image_source1) // len(prompts) new_prompts = [] - for i, _ in enumerate(image_to_continue): + for i, _ in enumerate(image_source1): new_prompts.append( prompts[ i//rep] ) prompts = new_prompts - # elif video_to_continue != None and len(video_to_continue) >0 : - # input_image_or_video_path = video_to_continue - # # pipeline.num_input_frames = max_frames - # # pipeline.max_frames = max_frames - # else: - # return - # else: - # input_image_or_video_path = None - if image_to_continue == None: - image_to_continue = [None] * len(prompts) - if image_to_end == None: - image_to_end = [None] * len(prompts) + + if image_source1 == None: + image_source1 = [None] * len(prompts) + if image_source2 == None: + image_source2 = [None] * len(prompts) + if image_source3 == None: + image_source3 = [None] * len(prompts) - for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) : + for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) : kwargs = { "prompt" : single_prompt, "negative_prompt" : negative_prompt, @@ -228,10 +254,11 @@ def process_prompt_and_add_tasks( "loras_choices" : loras_choices, "loras_mult_choices" : loras_mult_choices, "image_prompt_type" : image_prompt_type, - "image_to_continue": image_start, - "image_to_end" : image_end, - "video_to_continue" : video_to_continue , + "image_source1": image_source1, + "image_source2" : image_source2, + "image_source3" : image_source3 , "max_frames" : max_frames, + "remove_background_image_ref" : remove_background_image_ref, "temporal_upsampling" : temporal_upsampling, "spatial_upsampling" : spatial_upsampling, "RIFLEx_setting" : RIFLEx_setting, @@ -262,8 +289,9 @@ def add_video_task(**kwargs): queue = gen["queue"] task_id += 1 current_task_id = task_id - start_image_data = kwargs["image_to_continue"] - end_image_data = kwargs["image_to_end"] + start_image_data = kwargs["image_source1"] + start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data + end_image_data = kwargs["image_source2"] queue.append({ "id": current_task_id, @@ -275,7 +303,7 @@ def add_video_task(**kwargs): "prompt": kwargs["prompt"], "start_image_data": start_image_data, "end_image_data": end_image_data, - "start_image_data_base64": pil_to_base64_uri(start_image_data, format="jpeg", quality=70), + "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data], "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70) }) return update_queue_data(queue) @@ -342,6 +370,7 @@ def get_queue_table(queue): full_prompt = item['prompt'].replace('"', '"') prompt_cell = f'{truncated_prompt}' start_img_uri =item.get('start_image_data_base64') + start_img_uri = start_img_uri[0] if start_img_uri !=None else None end_img_uri = item.get('end_image_data_base64') thumbnail_size = "50px" num_steps = item.get('steps') @@ -694,6 +723,9 @@ attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() args.flow_reverse = True +processing_device = args.gpu +if len(processing_device) == 0: + processing_device 
="cuda" # torch.backends.cuda.matmul.allow_fp16_accumulation = True lock_ui_attention = False lock_ui_transformer = False @@ -706,7 +738,7 @@ quantizeTransformer = args.quantize_transformer check_loras = args.check_loras ==1 advanced = args.advanced -transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"] +transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"] transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] @@ -750,7 +782,7 @@ def get_default_settings(filename, i2v): "prompts": get_default_prompt(i2v), "resolution": "832x480", "video_length": 81, - "image_prompt_type" : 0, + "image_prompt_type" : 0 if i2v else "", "num_inference_steps": 30, "seed": -1, "repeat_generation": 1, @@ -1149,6 +1181,9 @@ def get_model_name(model_filename): if "Fun" in model_filename: model_name = "Fun InP image2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" + elif "Vace" in model_filename: + model_name = "Vace ControlNet text2video" + model_name += " 14B" if "14B" in model_filename else " 1.3B" elif "image" in model_filename: model_name = "Wan2.1 image2video" model_name += " 720p" if "720p" in model_filename else " 480p" @@ -1353,22 +1388,22 @@ def refresh_gallery(state, msg): end_img_md = "" prompt = task["prompt"] - if task.get('image2video'): - start_img_uri = task.get('start_image_data_base64') - end_img_uri = task.get('end_image_data_base64') - thumbnail_size = "100px" - if start_img_uri: - start_img_md = f'Start' - if end_img_uri: - end_img_md = f'End' + start_img_uri = task.get('start_image_data_base64') + start_img_uri = start_img_uri[0] if start_img_uri !=None else None + end_img_uri = task.get('end_image_data_base64') + thumbnail_size = "100px" + if start_img_uri: + start_img_md = f'Start' + if end_img_uri: + end_img_md = f'End' label = f"Prompt of Video being Generated" html = "" if start_img_md != "": html += "" - if end_img_md != "": - html += "" + if end_img_md != "": + html += "" html += "
" + prompt + "" + start_img_md + "" + end_img_md + "" + end_img_md + "
" html_output = gr.HTML(html, visible= True) @@ -1419,24 +1454,26 @@ def expand_slist(slist, num_inference_steps ): new_slist.append(slist[ int(pos)]) pos += inc return new_slist - def convert_image(image): - from PIL import ExifTags - image = image.convert('RGB') - for orientation in ExifTags.TAGS.keys(): - if ExifTags.TAGS[orientation]=='Orientation': - break - exif = image.getexif() - if not orientation in exif: - return image - if exif[orientation] == 3: - image=image.rotate(180, expand=True) - elif exif[orientation] == 6: - image=image.rotate(270, expand=True) - elif exif[orientation] == 8: - image=image.rotate(90, expand=True) - return image + from PIL import ExifTags, ImageOps + from typing import cast + + return cast(Image, ImageOps.exif_transpose(image)) + # image = image.convert('RGB') + # for orientation in ExifTags.TAGS.keys(): + # if ExifTags.TAGS[orientation]=='Orientation': + # break + # exif = image.getexif() + # return image + # if not orientation in exif: + # if exif[orientation] == 3: + # image=image.rotate(180, expand=True) + # elif exif[orientation] == 6: + # image=image.rotate(270, expand=True) + # elif exif[orientation] == 8: + # image=image.rotate(90, expand=True) + # return image def generate_video( task_id, @@ -1457,10 +1494,11 @@ def generate_video( loras_choices, loras_mult_choices, image_prompt_type, - image_to_continue, - image_to_end, - video_to_continue, + image_source1, + image_source2, + image_source3, max_frames, + remove_background_image_ref, temporal_upsampling, spatial_upsampling, RIFLEx_setting, @@ -1507,7 +1545,6 @@ def generate_video( gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.") return - if not image2video: width, height = resolution.split("x") @@ -1586,7 +1623,7 @@ def generate_video( enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1 # VAE Tiling - device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 + device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 joint_pass = boost ==1 #and profile != 1 and profile != 3 # TeaCache @@ -1615,6 +1652,17 @@ def generate_video( else: raise gr.Error("Teacache not supported for this model") + if "Vace" in model_filename: + resolution_reformated = str(height) + "*" + str(width) + src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2], + [image_source3], + [image_source1], + video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu", + trim_video=max_frames) + else: + src_video, src_mask, src_ref_images = None, None, None + + import random if seed == None or seed <0: seed = random.randint(0, 999999999) @@ -1673,8 +1721,8 @@ def generate_video( if image2video: samples = wan_model.generate( prompt, - image_to_continue, - image_to_end if image_to_end != None else None, + image_source1, + image_source2 if image_source2 != None else None, frame_num=(video_length // 4)* 4 + 1, max_area=MAX_AREA_CONFIGS[resolution], shift=flow_shift, @@ -1697,6 +1745,9 @@ def generate_video( else: samples = wan_model.generate( prompt, + input_frames = src_video, + input_ref_images= src_ref_images, + input_masks = src_mask, frame_num=(video_length // 4)* 4 + 1, size=(width, height), shift=flow_shift, @@ -1745,7 +1796,7 @@ def generate_video( new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should 
therefore reduce the video resolution or its number of frames." else: new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'") - tb = traceback.format_exc().split('\n')[:-2] + tb = traceback.format_exc().split('\n')[:-1] print('\n'.join(tb)) raise gr.Error(new_error, print_exception= False) @@ -1799,7 +1850,7 @@ def generate_video( if exp > 0: from rife.inference import temporal_interpolation - sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda") + sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device) fps = fps * 2**exp if len(spatial_upsampling) > 0: @@ -1831,8 +1882,7 @@ def generate_video( normalize=True, value_range=(-1, 1)) - - configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, + configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step) metadata_choice = server_config.get("metadata_choice","metadata") @@ -2294,7 +2344,7 @@ def switch_advanced(state, new_advanced, lset_name): return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name) -def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, +def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): loras = state["loras"] @@ -2330,18 +2380,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video" ui_settings["image_prompt_type"] = image_prompt_type else: + if "Vace" in transformer_filename_t2v or not image_metadata: + ui_settings["image_prompt_type"] = image_prompt_type + ui_settings["max_frames"] = max_frames + ui_settings["remove_background_image_ref"] = remove_background_image_ref ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video" return ui_settings -def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, +def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, 
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): if state.get("validate_success",0) != 1: return image2video = state["image2video"] - ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, + ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step) defaults_filename = get_settings_file_name(image2video) @@ -2379,6 +2433,25 @@ def download_loras(): writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}") return +def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio): + if args.multiple_images: + return gr.Gallery(visible = (image_prompt_type_radio == 1) ) + else: + return gr.Image(visible = (image_prompt_type_radio == 1) ) + +def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio): + vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"] + return gr.Column(visible= vace_model), gr.Radio(value= image_prompt_type_radio), gr.Gallery(visible = "I" in image_prompt_type_radio), gr.Video(visible= "V" in image_prompt_type_radio),gr.Video(visible= "M" in image_prompt_type_radio ), gr.Text(visible= "V" in image_prompt_type_radio) , gr.Checkbox(visible= "I" in image_prompt_type_radio) + +def check_refresh_input_type(state): + if not state["image2video"]: + model_file_name = state["image_input_type_model"] + model_file_needed= model_needed(False) + if model_file_name != model_file_needed: + state["image_input_type_model"] = model_file_needed + return gr.Text(value= str(time.time())) + return gr.Text() + def generate_video_tab(image2video=False): filename = transformer_filename_i2v if image2video else transformer_filename_t2v ui_defaults= get_default_settings(filename, image2video) @@ -2387,6 +2460,7 @@ def generate_video_tab(image2video=False): state_dict["advanced"] = advanced state_dict["loras_model"] = filename + state_dict["image_input_type_model"] = filename state_dict["image2video"] = image2video gen = dict() gen["queue"] = [] @@ -2461,31 +2535,51 @@ def generate_video_tab(image2video=False): save_lset_btn = gr.Button("Save", size="sm", min_width= 1) delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) - video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) ####### - image_prompt_type= ui_defaults.get("image_prompt_type",0) - image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video) - if args.multiple_images: - image_to_continue = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], 
object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video) - else: - image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video) + state = gr.State(state_dict) + vace_model = "Vace" in filename and not image2video + trigger_refresh_input_type = gr.Text(interactive= False, visible= False) + with gr.Column(visible= image2video or vace_model) as image_prompt_column: + if image2video: + image_source3 = gr.Video(label= "Placeholder", visible= image2video and False) - if args.multiple_images: - image_to_end = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1) - else: - image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1) + image_prompt_type= ui_defaults.get("image_prompt_type",0) + image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3) - def switch_image_prompt_type_radio(image_prompt_type_radio): - if args.multiple_images: - return gr.Gallery(visible = (image_prompt_type_radio == 1) ) + if args.multiple_images: + image_source1 = gr.Gallery( + label="Images as starting points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True) + else: + image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil") + + if args.multiple_images: + image_source2 = gr.Gallery( + label="Images as ending points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1) + else: + image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1) + + + image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2]) + max_frames = gr.Slider(1, 100,step=1, visible = False) + remove_background_image_ref = gr.Text(visible = False) else: - return gr.Image(visible = (image_prompt_type_radio == 1) ) + image_prompt_type= ui_defaults.get("image_prompt_type","I") + image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model) + image_source1 = gr.Gallery( + label="Reference Images of Faces and / or Object to be found in the Video", type ="pil", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type ) - image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end]) + image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type ) + with gr.Row(): + max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 ) + remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. 
Background", visible= "I" in image_prompt_type, scale =1 ) + + image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type ) + + + gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref]) advanced_prompt = advanced @@ -2518,7 +2612,6 @@ def generate_video_tab(image2video=False): wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3) wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) - state = gr.State(state_dict) with gr.Row(): if image2video: resolution = gr.Dropdown( @@ -2555,8 +2648,6 @@ def generate_video_tab(image2video=False): video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)") with gr.Column(): num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps") - with gr.Row(): - max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=image2video and False) ######### show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced) with gr.Row(visible=advanced) as advanced_row: with gr.Column(): @@ -2605,7 +2696,7 @@ def generate_video_tab(image2video=False): tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation") with gr.Row(): - gr.Markdown("Upsampling") + gr.Markdown("Upsampling - postprocessing that may improve fluidity and the size of the video") with gr.Row(): temporal_upsampling_choice = gr.Dropdown( choices=[ @@ -2687,9 +2778,10 @@ def generate_video_tab(image2video=False): show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) with gr.Column(): - gen_status = gr.Text(label="Status", interactive= False) - full_sync = gr.Text(label="Status", interactive= False, visible= False) - light_sync = gr.Text(label="Status", interactive= False, visible= False) + gen_status = gr.Text(interactive= False) + full_sync = gr.Text(interactive= False, visible= False) + light_sync = gr.Text(interactive= False, visible= False) + gen_progress_html = gr.HTML( label="Status", value="Idle", @@ -2709,8 +2801,8 @@ def generate_video_tab(image2video=False): abort_btn = gr.Button("Abort") queue_df = gr.DataFrame( - headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""], - datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], + headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""], + datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"], 
interactive=False, col_count=(9, "fixed"), @@ -2792,7 +2884,7 @@ def generate_video_tab(image2video=False): show_progress="hidden" ) save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( - save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, + save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = []) save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) @@ -2808,21 +2900,30 @@ def generate_video_tab(image2video=False): refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) output.select(select_video, state, None ) + + gen_status.change(refresh_gallery, inputs = [state, gen_status], outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]) - full_sync.change(refresh_gallery, + full_sync.change(fn= check_refresh_input_type, + inputs= [state], + outputs= [trigger_refresh_input_type] + ).then(fn=refresh_gallery, inputs = [state, gen_status], outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] - ).then( fn=wait_tasks_done, + ).then(fn=wait_tasks_done, inputs= [state], outputs =[gen_status], ).then(finalize_generation, inputs= [state], outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] ) - light_sync.change(refresh_gallery, + + light_sync.change(fn= check_refresh_input_type, + inputs= [state], + outputs= [trigger_refresh_input_type] + ).then(fn=refresh_gallery, inputs = [state, gen_status], outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn] ) @@ -2848,10 +2949,11 @@ def generate_video_tab(image2video=False): loras_choices, loras_mult_choices, image_prompt_type_radio, - image_to_continue, - image_to_end, - video_to_continue, + image_source1, + image_source2, + image_source3, max_frames, + remove_background_image_ref, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, @@ -2902,7 +3004,7 @@ def generate_video_tab(image2video=False): ) return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state -def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state): +def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state): with gr.Row(): with gr.Row(scale =2): gr.Markdown("Wan2GP's Lora Festival ! 
Press the following button to download i2v Remade Loras collection (and bonuses Loras).") @@ -2928,6 +3030,7 @@ def generate_configuration_tab(): ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0), ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1), ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2), + ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3), ], value= index, label="Transformer model for Text to Video", @@ -3108,16 +3211,17 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData): t2v_light_sync = gr.Text() i2v_full_sync = gr.Text() t2v_full_sync = gr.Text() + + last_tab_was_image2video =global_state.get("last_tab_was_image2video", None) + if last_tab_was_image2video == None or last_tab_was_image2video: + gen = i2v_state["gen"] + t2v_state["gen"] = gen + else: + gen = t2v_state["gen"] + i2v_state["gen"] = gen + + if new_t2v or new_i2v: - last_tab_was_image2video =global_state.get("last_tab_was_image2video", None) - if last_tab_was_image2video == None or last_tab_was_image2video: - gen = i2v_state["gen"] - t2v_state["gen"] = gen - else: - gen = t2v_state["gen"] - i2v_state["gen"] = gen - - if last_tab_was_image2video != None and new_t2v != new_i2v: gen_location = gen.get("location", None) if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) : @@ -3131,7 +3235,6 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData): else: t2v_light_sync = gr.Text(str(time.time())) - global_state["last_tab_was_image2video"] = new_i2v if(server_config.get("reload_model",2) == 1): @@ -3433,7 +3536,7 @@ def create_demo(): } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo: - gr.Markdown("

Wan 2.1GP v3.4 by DeepBeepMeep (Updates)
") + gr.Markdown("
Wan 2.1GP v4.0 by DeepBeepMeep (Updates)
") gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): @@ -3454,7 +3557,7 @@ def create_demo(): i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: - generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) + generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state) with gr.Tab("Configuration"): generate_configuration_tab() with gr.Tab("About"): diff --git a/requirements.txt b/requirements.txt index 7576271..a4cd3b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,11 +11,15 @@ easydict ftfy dashscope imageio-ffmpeg -# flash_attn +# flash_attn gradio>=5.0.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 mmgp==3.3.4 peft==0.14.0 -mutagen \ No newline at end of file +mutagen +decord +onnxruntime-gpu +rembg[gpu]==2.0.65 +# rembg==2.0.65 \ No newline at end of file diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py index c72d2d0..e3f539b 100644 --- a/wan/configs/__init__.py +++ b/wan/configs/__init__.py @@ -40,3 +40,17 @@ SUPPORTED_SIZES = { 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 't2i-14B': tuple(SIZE_CONFIGS.keys()), } + +VACE_SIZE_CONFIGS = { + '480*832': (480, 832), + '832*480': (832, 480), +} + +VACE_MAX_AREA_CONFIGS = { + '480*832': 480 * 832, + '832*480': 832 * 480, +} + +VACE_SUPPORTED_SIZES = { + 'vace-1.3B': ('480*832', '832*480'), +} diff --git a/wan/modules/model.py b/wan/modules/model.py index 2daa00c..5eba92b 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -377,6 +377,7 @@ class WanI2VCrossAttention(WanSelfAttention): return x + WAN_CROSSATTENTION_CLASSES = { 't2v_cross_attn': WanT2VCrossAttention, 'i2v_cross_attn': WanI2VCrossAttention, @@ -393,7 +394,9 @@ class WanAttentionBlock(nn.Module): window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, - eps=1e-6): + eps=1e-6, + block_id=None + ): super().__init__() self.dim = dim self.ffn_dim = ffn_dim @@ -422,6 +425,7 @@ class WanAttentionBlock(nn.Module): # modulation self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.block_id = block_id def forward( self, @@ -432,6 +436,8 @@ class WanAttentionBlock(nn.Module): freqs, context, context_lens, + hints= None, + context_scale=1.0, ): r""" Args: @@ -480,10 +486,49 @@ class WanAttentionBlock(nn.Module): x.addcmul_(y, e[5]) + if self.block_id is not None and hints != None: + if context_scale == 1: + x.add_(hints[self.block_id]) + else: + x.add_(hints[self.block_id], alpha =context_scale) + return x - return x - +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + def forward(self, c, x, **kwargs): + # behold dbm magic ! 
+ if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = c + c = all_c.pop(-1) + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + return all_c + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6): @@ -544,6 +589,8 @@ class WanModel(ModelMixin, ConfigMixin): @register_to_config def __init__(self, + vace_layers=None, + vace_in_dim=None, model_type='t2v', patch_size=(1, 2, 2), text_len=512, @@ -628,12 +675,13 @@ class WanModel(ModelMixin, ConfigMixin): self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) # blocks - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' - self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps) - for _ in range(num_layers) - ]) + if vace_layers == None: + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) # head self.head = Head(dim, out_dim, patch_size, eps) @@ -646,6 +694,33 @@ class WanModel(ModelMixin, ConfigMixin): # initialize weights self.init_weights() + if vace_layers != None: + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): rescale_func = np.poly1d(self.coefficients) @@ -688,6 +763,36 @@ class WanModel(ModelMixin, ConfigMixin): self.rel_l1_thresh = best_threshold print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") return best_threshold + + def forward_vace( + self, + x, + vace_context, + seq_len, + context, + e, + kwargs + ): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + if (len(c) == 1 and seq_len == c[0].size(1)): + c = c[0] + else: + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + c = block(c, context= context, e= e, **new_kwargs) + hints = c[:-1] + + return hints def forward( self, @@ -695,6 +800,8 @@ class WanModel(ModelMixin, ConfigMixin): t, context, seq_len, + vace_context = None, + vace_context_scale=1.0, clip_fea=None, y=None, freqs = None, @@ -829,13 
+936,23 @@ class WanModel(ModelMixin, ConfigMixin): self.previous_residual_cond = None ori_hidden_states = x_list[0].clone() # arguments + kwargs = dict( - # e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs, - # context=context, context_lens=context_lens) + + if vace_context == None: + hints_list = [None ] *len(x_list) + else: + hints_list = [] + for x, context in zip(x_list, context_list) : + hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs)) + del x, context + kwargs['context_scale'] = vace_context_scale + + for block_idx, block in enumerate(self.blocks): offload.shared_state["layer"] = block_idx if callback != None: @@ -852,9 +969,10 @@ class WanModel(ModelMixin, ConfigMixin): x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs) else: - for i, (x, context) in enumerate(zip(x_list, context_list)): - x_list[i] = block(x, context = context, e= e0, **kwargs) + for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)): + x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs) del x + del context, hints if self.enable_teacache: if joint_pass: diff --git a/wan/text2video.py b/wan/text2video.py index cdcbd4f..befe139 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -13,7 +13,9 @@ import torch import torch.cuda.amp as amp import torch.distributed as dist from tqdm import tqdm - +from PIL import Image +import torchvision.transforms.functional as TF +import torch.nn.functional as F from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel @@ -22,6 +24,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed +from .utils.vace_preprocessor import VaceVideoProcessor def optimized_scale(positive_flat, negative_flat): @@ -105,8 +108,6 @@ class WanT2V: self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) - - self.model.eval().requires_grad_(False) if use_usp: @@ -132,8 +133,148 @@ class WanT2V: self.sample_neg_prompt = config.sample_neg_prompt + if "Vace" in model_filename: + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480*832, + max_area=480*832, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = self.vae.encode(frames, tile_size = tile_size) + else: + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = self.vae.encode(inactive, tile_size = tile_size) + reactive = self.vae.encode(reactive, tile_size = tile_size) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + else: + ref_latent = self.vae.encode(refs, tile_size = tile_size) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in 
ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // self.vae_stride[0]) + height = 2 * (int(height) // (self.vae_stride[1] * 2)) + width = 2 * (int(width) // (self.vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, self.vae_stride[1], width, self.vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + self.vae_stride[1] * self.vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0): + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_video_shape = src_video[i].shape + if src_video_shape[1] != num_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video) + src_video[i] = src_video[i].to(device) + src_video_shape = src_video[i].shape + if src_video_shape[1] != num_frames: + src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + 
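+ # Letterbox the reference image: keep its aspect ratio, resize bilinearly to fit inside the
+ # target canvas, then paste it centered on a white background so it matches the video resolution.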
new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, tile_size= 0 ): + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return self.vae.decode(trimed_zs, tile_size= tile_size) + def generate(self, input_prompt, + input_frames= None, + input_masks = None, + input_ref_images = None, + context_scale=1.0, size=(1280, 720), frame_num=81, shift=5.0, @@ -187,14 +328,6 @@ class WanT2V: - W: Frame width from size) """ # preprocess - F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1] / self.sp_size) * self.sp_size if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -213,6 +346,29 @@ class WanT2V: context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] + + if input_frames != None: + # vace context encode + input_frames = [u.to(self.device) for u in input_frames] + input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] + input_masks = [u.to(self.device) for u in input_masks] + + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + else: + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + noise = [ torch.randn( @@ -261,10 +417,12 @@ class WanT2V: arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} + if input_frames != None: + vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} + arg_c.update(vace_dict) + arg_null.update(vace_dict) + arg_both.update(vace_dict) - # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} - # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} - # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} if self.model.enable_teacache: 
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) if callback != None: @@ -281,7 +439,7 @@ class WanT2V: # self.model.to(self.device) if joint_pass: noise_pred_cond, noise_pred_uncond = self.model( - latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both) + latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) if self._interrupt: return None else: @@ -329,7 +487,11 @@ class WanT2V: self.model.cpu() torch.cuda.empty_cache() if self.rank == 0: - videos = self.vae.decode(x0, VAE_tile_size) + + if input_frames == None: + videos = self.vae.decode(x0, VAE_tile_size) + else: + videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) del noise, latents diff --git a/wan/utils/utils.py b/wan/utils/utils.py index e19c298..d4e237d 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -3,21 +3,70 @@ import argparse import binascii import os import os.path as osp +import torchvision.transforms.functional as TF +import torch.nn.functional as F import imageio import torch +import decord import torchvision from PIL import Image import numpy as np +from rembg import remove, new_session + __all__ = ['cache_video', 'cache_image', 'str2bool'] + + +from PIL import Image + +def get_video_frame(file_name, frame_no): + decord.bridge.set_bridge('torch') + reader = decord.VideoReader(file_name) + + frame = reader.get_batch([frame_no]).squeeze(0) + img = Image.fromarray(frame.numpy().astype(np.uint8)) + return img + def resize_lanczos(img, h, w): img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) img = img.resize((w,h), resample=Image.Resampling.LANCZOS) return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) +def remove_background(img, session=None): + if session ==None: + session = new_session() + img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) + img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') + return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) + + + + +def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ): + if rm_background: + session = new_session() + + output_list =[] + for img in img_list: + width, height = img.size + white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 ) + scale = min(canvas_height / height, canvas_width / width) + new_height = int(height * scale) + new_width = int(width * scale) + resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) + if rm_background: + resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image) + img = Image.fromarray(white_canvas) + output_list.append(img) + return output_list + + def rand_name(length=8, suffix=''): name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') if suffix: diff --git a/wan/utils/vace_preprocessor.py b/wan/utils/vace_preprocessor.py new file mode 100644 index 0000000..7c10719 --- /dev/null +++ b/wan/utils/vace_preprocessor.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. 
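+# Preprocessing helpers for VACE conditioning inputs: VaceImageProcessor resizes and
+# center-crops reference images, while VaceVideoProcessor loads control videos (and masks),
+# resizing and cropping them so the resulting token count fits the model's sequence length.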
+import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF + + +class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): + self.downsample = downsample + self.seq_len = seq_len + + def _pillow_convert(self, image, cvt_type='RGB'): + if image.mode != cvt_type: + if image.mode == 'P': + image = image.convert(f'{cvt_type}A') + if image.mode == f'{cvt_type}A': + bg = Image.new(cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) + bg.paste(image, (0, 0), mask=image) + image = bg + else: + image = image.convert(cvt_type) + return image + + def _load_image(self, img_path): + if img_path is None or img_path == '': + return None + img = Image.open(img_path) + img = self._pillow_convert(img) + return img + + def _resize_crop(self, img, oh, ow, normalize=True): + """ + Resize, center crop, convert to tensor, and normalize. + """ + # resize and crop + iw, ih = img.size + if iw != ow or ih != oh: + # resize + scale = max(ow / iw, oh / ih) + img = img.resize( + (round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS + ) + assert img.width >= ow and img.height >= oh + + # center crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + # normalize + if normalize: + img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) + return img + + def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): + return self._resize_crop(img, oh, ow, normalize) + + def load_image(self, data_key, **kwargs): + return self.load_image_batch(data_key, **kwargs) + + def load_image_pair(self, data_key, data_key2, **kwargs): + return self.load_image_batch(data_key, data_key2, **kwargs) + + def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + seq_len = self.seq_len if seq_len is None else seq_len + imgs = [] + for data_key in data_key_batch: + img = self._load_image(data_key) + imgs.append(img) + w, h = imgs[0].size + dh, dw = self.downsample[1:] + + # compute output size + scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) + oh = int(h * scale) // dh * dh + ow = int(w * scale) // dw * dw + assert (oh // dh) * (ow // dw) <= seq_len + imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] + return *imgs, (oh, ow) + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate( + video, + size=(round(scale * ih), 
round(scale * iw)), + mode='bicubic', + antialias=True + ) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1:y1 + oh, x1:x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.) + return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + # min/max area of the [latent video] + min_area_z = self.min_area / (dh * dw) + max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + + # sample a frame number of the [latent video] + rand_area_z = np.square(np.power(2, rng.uniform( + np.log2(np.sqrt(min_area_z)), + np.log2(np.sqrt(max_area_z)) + ))) + of = min( + (int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / rand_area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(max_area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1] + ), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0): + import math + target_fps = self.max_fps + video_duration = frame_timestamps[-1][1] + video_frame_duration = 1 /fps + target_frame_duration = 1 / target_fps + + cur_time = 0 + target_time = 0 + frame_no = 0 + frame_ids =[] + for i in range(max_frames): + add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration ) + frame_no += add_frames_count + frame_ids.append(frame_no) + cur_time += add_frames_count * video_frame_duration + target_time += target_frame_duration + if cur_time > video_duration: + break + + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + seq_len = self.seq_len + # min/max area of the [latent video] + min_area_z = self.min_area / (dh * dw) + # max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + max_area_z = min_area_z # workaround bug + # sample a frame number of the [latent video] + rand_area_z = np.square(np.power(2, rng.uniform( + np.log2(np.sqrt(min_area_z)), + np.log2(np.sqrt(max_area_z)) + ))) + + seq_len = max_area_z * ((max_frames- 1) // df +1) + + # of = min( + # (len(frame_ids) - 1) // df + 1, + # int(seq_len / rand_area_z) + # ) + of = (len(frame_ids) - 1) // df + 1 + + + # deduce target shape of the [latent video] + # target_area_z = min(max_area_z, int(seq_len / of)) + target_area_z = max_area_z + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def 
_get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames=0):
+        if self.keep_last:
+            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
+        else:
+            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)  # the default sampler does not take max_frames
+
+    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
+        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
+
+    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
+        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
+
+    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames=0, trim_video=0, **kwargs):
+        rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
+        # read video
+        import decord
+        decord.bridge.set_bridge('torch')
+        readers = []
+        for data_k in data_key_batch:
+            reader = decord.VideoReader(data_k)
+            readers.append(reader)
+
+        fps = readers[0].get_avg_fps()
+        length = min([len(r) for r in readers])
+        frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
+        frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
+        # # frame_timestamps = frame_timestamps[ :max_frames]
+        # if trim_video > 0:
+        #     frame_timestamps = frame_timestamps[ :trim_video]
+        max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
+        h, w = readers[0].next().shape[:2]
+        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
+
+        # preprocess video
+        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
+        videos = [self._video_preprocess(video, oh, ow) for video in videos]
+        return *videos, frame_ids, (oh, ow), fps
+        # return videos if len(videos) > 1 else videos[0]
+
+
+def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
+    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
+        if sub_src_video is None and sub_src_mask is None:
+            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
+            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
+    for i, ref_images in enumerate(src_ref_images):
+        if ref_images is not None:
+            for j, ref_img in enumerate(ref_images):
+                if ref_img is not None and ref_img.shape[-2:] != image_size:
+                    canvas_height, canvas_width = image_size
+                    ref_height, ref_width = ref_img.shape[-2:]
+                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
+                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                    new_height = int(ref_height * scale)
+                    new_width = int(ref_width * scale)
+                    resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+                    top = (canvas_height - new_height) // 2
+                    left = (canvas_width - new_width) // 2
+                    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+                    src_ref_images[i][j] = white_canvas
+    return src_video, src_mask, src_ref_images
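
A note on the latent layout used by the new decode path in text2video.py: the slicing in WanT2V.decode_latent (z[:, len(refs):, :, :]) drops one leading position of the temporal axis per reference image before the VAE decode, which suggests the reference-image latents are prepended in front of the video latents. The snippet below is a minimal, standalone sketch of that trimming step with dummy tensors; the channel/temporal/spatial sizes are illustrative assumptions, not values produced by the pipeline.

    import torch

    # dummy latent of shape (C, T, H, W): two reference "frames" prepended along T
    C, T_refs, T_video, H, W = 16, 2, 21, 60, 104   # illustrative sizes only
    z = torch.randn(C, T_refs + T_video, H, W)
    refs = [object(), object()]                     # stand-ins for two reference images

    # same indexing as WanT2V.decode_latent: drop the leading reference latents
    trimmed = z[:, len(refs):, :, :]
    assert trimmed.shape == (C, T_video, H, W)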
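The resize_and_remove_background helper added to wan/utils/utils.py scales each image to fit the canvas, optionally strips the background with rembg, and pastes the result centered on a white canvas. The sketch below reproduces just the scale-and-center step with PIL and NumPy so the geometry is easy to verify; the 640x480 canvas and the synthetic gradient image are arbitrary choices for the example, and the rembg pass is omitted.

    import numpy as np
    from PIL import Image

    canvas_w, canvas_h = 640, 480                   # assumed target canvas size
    img = Image.fromarray(np.uint8(np.linspace(0, 255, 200 * 300 * 3)).reshape(200, 300, 3))

    scale = min(canvas_h / img.height, canvas_w / img.width)
    new_w, new_h = int(img.width * scale), int(img.height * scale)
    resized = img.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)

    canvas = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)  # white background
    top, left = (canvas_h - new_h) // 2, (canvas_w - new_w) // 2
    canvas[top:top + new_h, left:left + new_w, :] = np.array(resized)
    out = Image.fromarray(canvas)                   # same letterboxed layout as the helper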
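Finally, a minimal usage sketch for prepare_source from the new wan/utils/vace_preprocessor.py: missing source videos/masks are replaced by zero/one tensors of the requested length, and undersized reference images are letterboxed onto a white canvas at the target size. The frame count and (height, width) below mirror common Wan 1.3B settings but are assumptions for this example, and the reference image is a random stand-in in [-1, 1] rather than a real input.

    import torch
    from wan.utils.vace_preprocessor import prepare_source  # module added by this patch

    device = torch.device("cpu")
    num_frames, image_size = 81, (480, 832)                  # assumed frame count and (H, W)

    src_video = [None]                                       # no control video for this sample
    src_mask = [None]
    src_ref_images = [[torch.rand(3, 1, 240, 416) * 2 - 1]]  # one undersized reference image

    src_video, src_mask, src_ref_images = prepare_source(
        src_video, src_mask, src_ref_images, num_frames, image_size, device)

    print(src_video[0].shape)          # torch.Size([3, 81, 480, 832]), zeros
    print(src_mask[0].shape)           # torch.Size([1, 81, 480, 832]), ones
    print(src_ref_images[0][0].shape)  # letterboxed to torch.Size([3, 1, 480, 832])

Note that the padding canvas is built from torch.ones, i.e. pure white in the [-1, 1] pixel range the pipeline uses, matching the inline comment in prepare_source.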