diff --git a/wgp.py b/wgp.py index 2199336..707aff8 100644 --- a/wgp.py +++ b/wgp.py @@ -266,7 +266,7 @@ def process_prompt_and_add_tasks(state, model_choice): } inputs.update(extra_inputs) add_video_task(**inputs) - elif "image2video" in model_filename or "Fun_InP" in model_filename : + elif test_class_i2v(model_filename) : image_prompt_type = inputs["image_prompt_type"] image_start = inputs["image_start"] @@ -1366,7 +1366,7 @@ check_loras = args.check_loras ==1 advanced = args.advanced transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"] -transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] +transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"] transformer_choices = transformer_choices_t2v + transformer_choices_i2v text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] server_config_filename = "wgp_config.json" @@ -1404,9 +1404,10 @@ else: server_config = json.loads(text) -model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B"] +model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", - "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B" } + "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B", + "flf2v_720p" : "FLF2V_720p" } def get_model_type(model_filename): @@ -1416,7 +1417,29 @@ def get_model_type(model_filename): raise Exception("Unknown model:" + model_filename) def test_class_i2v(model_filename): - return "image2video" in model_filename or "Fun_InP" in model_filename + return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename + +def get_model_name(model_filename): + if "Fun" in model_filename: + model_name = "Fun InP image2video" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + elif "Vace" in model_filename: + model_name = "Vace ControlNet" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + elif "image" in model_filename: + model_name = "Wan2.1 image2video" + model_name += " 720p" if "720p" in model_filename else " 480p" + elif "recam" in model_filename: + model_name = "ReCamMaster" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + elif "FLF2V" in model_filename: + model_name = "Wan2.1 FLF2V" + model_name += " 720p" if "720p" in model_filename else " 480p" + else: + model_name = "Wan2.1 text2video" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + + return model_name def get_model_filename(model_type, quantization): @@ -1444,12 +1467,12 @@ def get_default_settings(filename): return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field." else: return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect." - i2v = "image2video" in filename or "Fun_InP" in filename + i2v = test_class_i2v(filename) defaults_filename = get_settings_file_name(filename) if not Path(defaults_filename).is_file(): ui_defaults = { "prompt": get_default_prompt(i2v), - "resolution": "832x480", + "resolution": "1280x720" if "720p" in filename else "832x480", "video_length": 81, "num_inference_steps": 30, "seed": -1, @@ -1805,6 +1828,7 @@ def load_models(model_filename): major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) default_dtype = torch.float16 if major < 8 else torch.bfloat16 + # default_dtype = torch.bfloat16 if default_dtype == torch.float16 or args.fp16: print("Switching to f16 model as GPU architecture doesn't support bf16") if "quanto" in model_filename: @@ -1851,24 +1875,6 @@ def get_default_flow(filename, i2v): return 7.0 if "480p" in filename and i2v else 5.0 -def get_model_name(model_filename): - if "Fun" in model_filename: - model_name = "Fun InP image2video" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - elif "Vace" in model_filename: - model_name = "Vace ControlNet" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - elif "image" in model_filename: - model_name = "Wan2.1 image2video" - model_name += " 720p" if "720p" in model_filename else " 480p" - elif "recam" in model_filename: - model_name = "ReCamMaster" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - else: - model_name = "Wan2.1 text2video" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - - return model_name @@ -2594,7 +2600,7 @@ def generate_video( slg_end = slg_end_perc/100, cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, - add_frames_for_end_image = not "Fun_InP" in model_filename, + add_frames_for_end_image = "image2video" in model_filename ) else: samples = wan_model.generate( @@ -3221,7 +3227,7 @@ def prepare_inputs_dict(target, inputs ): if target == "settings": return inputs - if not any(k in model_filename for k in ["image2video", "Fun_InP"]): + if not test_class_i2v(model_filename): inputs.pop("image_prompt_type") @@ -3563,7 +3569,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if not update_form: state = gr.State(state_dict) trigger_refresh_input_type = gr.Text(interactive= False, visible= False) - with gr.Column(visible= "image2video" in model_filename or "Fun_InP" in model_filename ) as image_prompt_column: + with gr.Column(visible= test_class_i2v(model_filename) ) as image_prompt_column: image_prompt_type_value= ui_defaults.get("image_prompt_type","S") image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, scale= 3) @@ -3676,7 +3682,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) with gr.Row(): - if "image2video" in model_filename or "Fun_InP" in model_filename: + if test_class_i2v(model_filename): resolution = gr.Dropdown( choices=[ # 720p @@ -4670,7 +4676,7 @@ def create_demo(): theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md") with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo: - gr.Markdown("