"
- model_name = get_model_name(model_filename)
+# model_name = get_model_name(model_filename)
- header += model_name
- header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
+# header += model_name
+# header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
+# if attention_mode not in attention_modes_installed:
+# header += " -NOT INSTALLED-"
+# elif attention_mode not in attention_modes_supported:
+# header += " -NOT SUPPORTED-"
+
+# if compile:
+# header += ", pytorch compilation ON"
+# header += ")
"
+
+
+# return header
+
+
+def generate_header(compile, attention_mode):
+
+ header = "
Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
if attention_mode not in attention_modes_installed:
header += " -NOT INSTALLED-"
elif attention_mode not in attention_modes_supported:
@@ -1208,14 +1223,12 @@ def generate_header(model_filename, compile, attention_mode):
if compile:
header += ", pytorch compilation ON"
- header += ") "
-
+ header += " "
return header
def apply_changes( state,
- transformer_t2v_choice,
- transformer_i2v_choice,
+ transformer_type_choice,
text_encoder_choice,
save_path_choice,
attention_choice,
@@ -1223,7 +1236,7 @@ def apply_changes( state,
profile_choice,
vae_config_choice,
metadata_choice,
- default_ui_choice ="t2v",
+ quantization_choice,
boost_choice = 1,
clear_file_list = 0,
reload_choice = 1
@@ -1235,15 +1248,14 @@ def apply_changes( state,
return
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
server_config = {"attention_mode" : attention_choice,
- "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
- "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
+ "transformer_type": transformer_type_choice,
"text_encoder_filename" : text_encoder_choices[text_encoder_choice],
"save_path" : save_path_choice,
"compile" : compile_choice,
"profile" : profile_choice,
"vae_config" : vae_config_choice,
"metadata_choice": metadata_choice,
- "default_ui" : default_ui_choice,
+ "transformer_quantization" : quantization_choice,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
"reload_model" : reload_choice,
@@ -1255,7 +1267,6 @@ def apply_changes( state,
old_server_config = json.loads(text)
if lock_ui_transformer:
server_config["transformer_filename"] = old_server_config["transformer_filename"]
- server_config["transformer_filename_i2v"] = old_server_config["transformer_filename_i2v"]
if lock_ui_attention:
server_config["attention_mode"] = old_server_config["attention_mode"]
if lock_ui_compile:
@@ -1270,15 +1281,17 @@ def apply_changes( state,
if v != v_old:
changes.append(k)
- global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir, reload_needed
+ global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type
attention_mode = server_config["attention_mode"]
profile = server_config["profile"]
compile = server_config["compile"]
- transformer_filename_t2v = server_config["transformer_filename"]
- transformer_filename_i2v = server_config["transformer_filename_i2v"]
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
boost = server_config["boost"]
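+ # new settings: the model reload policy, the transformer quantization choice, and a single transformer_filename resolved from them (replaces the old separate t2v / i2v filenames)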
+ reload_model = server_config["reload_model"]
+ transformer_quantization = server_config["transformer_quantization"]
+ transformer_filename = get_model_filename(transformer_type, transformer_quantization)
+
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
pass
else:
@@ -1357,20 +1370,12 @@ def abort_generation(state):
else:
return "", gr.Button(interactive= True)
-def is_gen_location(state):
- gen = get_gen_info(state)
- gen_location = gen.get("location",None)
- if gen_location == None:
- return None
- return state["image2video"] == gen_location
-
def refresh_gallery(state, msg):
gen = get_gen_info(state)
- if is_gen_location(state):
- gen["last_msg"] = msg
+ gen["last_msg"] = msg
file_list = gen.get("file_list", None)
choice = gen.get("selected",0)
in_progress = "in_progress" in gen
@@ -1391,6 +1396,7 @@ def refresh_gallery(state, msg):
start_img_uri = task.get('start_image_data_base64')
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
end_img_uri = task.get('end_image_data_base64')
+ end_img_uri = end_img_uri[0] if end_img_uri !=None else None
thumbnail_size = "100px"
if start_img_uri:
start_img_md = f'<img src="{start_img_uri}" style="max-height:{thumbnail_size}"/>'
@@ -1489,14 +1495,17 @@ def generate_video(
embedded_guidance_scale,
repeat_generation,
multi_images_gen_type,
- tea_cache,
+ tea_cache_setting,
tea_cache_start_step_perc,
- loras_choices,
- loras_mult_choices,
+ activated_loras,
+ loras_multipliers,
image_prompt_type,
- image_source1,
- image_source2,
- image_source3,
+ image_start,
+ image_end,
+ video_prompt_type,
+ image_refs,
+ video_guide,
+ video_mask,
max_frames,
remove_background_image_ref,
temporal_upsampling,
@@ -1504,39 +1513,41 @@ def generate_video(
RIFLEx_setting,
slg_switch,
slg_layers,
- slg_start,
- slg_end,
+ slg_start_perc,
+ slg_end_perc,
cfg_star_switch,
cfg_zero_step,
state,
- image2video
+ model_filename
):
- global wan_model, offloadobj, reload_needed, last_model_type
+ global wan_model, offloadobj, reload_needed
gen = get_gen_info(state)
file_list = gen["file_list"]
prompt_no = gen["prompt_no"]
-
- file_model_needed = model_needed(image2video)
- # queue = gen.get("queue", [])
- # with lock:
- # queue_not_empty = len(queue) > 0
- # if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
- if file_model_needed != model_filename or reload_needed:
- del wan_model
+
+
+ # if wan_model == None:
+ # gr.Info("Unable to generate a Video while a new configuration is being applied.")
+ # return
+
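+ # if the model is kept resident (reload policy other than 3), wait until whichever callback is currently loading it (e.g. preload_model) has finished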
+ if reload_model !=3 :
+ while wan_model == None:
+ time.sleep(1)
+
+ if model_filename != transformer_filename or reload_needed:
+ wan_model = None
if offloadobj is not None:
offloadobj.release()
- del offloadobj
+ offloadobj = None
gc.collect()
- yield f"Loading model {get_model_name(file_model_needed)}..."
- wan_model, offloadobj, trans = load_models(image2video)
+ yield f"Loading model {get_model_name(model_filename)}..."
+ wan_model, offloadobj, trans = load_models(model_filename)
yield f"Model loaded"
reload_needed= False
- if wan_model == None:
- gr.Info("Unable to generate a Video while a new configuration is being applied.")
if attention_mode == "auto":
attn = get_auto_attention()
elif attention_mode in attention_modes_supported:
@@ -1544,11 +1555,10 @@ def generate_video(
else:
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
return
-
- if not image2video:
- width, height = resolution.split("x")
- width, height = int(width), int(height)
+ width, height = resolution.split("x")
+ width, height = int(width), int(height)
+ resolution_reformated = str(height) + "*" + str(width)
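+ # e.g. "832x480" becomes "480*832", the height*width key format used by VACE_SIZE_CONFIGS / MAX_AREA_CONFIGS below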
if slg_switch == 0:
slg_layers = None
@@ -1589,11 +1599,11 @@ def generate_video(
except ValueError:
return False
list_mult_choices_nums = []
- if len(loras_mult_choices) > 0:
- loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n")
+ if len(loras_multipliers) > 0:
+ loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n")
loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
- loras_mult_choices = " ".join(loras_mult_choices_list)
- list_mult_choices_str = loras_mult_choices.split(" ")
+ loras_multipliers = " ".join(loras_mult_choices_list)
+ list_mult_choices_str = loras_multipliers.split(" ")
for i, mult in enumerate(list_mult_choices_str):
mult = mult.strip()
if "," in mult:
@@ -1609,9 +1619,9 @@ def generate_video(
if not is_float(mult):
raise gr.Error(f"Lora Multiplier no {i+1} ({mult}) is invalid")
list_mult_choices_nums.append(float(mult))
- if len(list_mult_choices_nums ) < len(loras_choices):
- list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
- loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
+ if len(list_mult_choices_nums ) < len(activated_loras):
+ list_mult_choices_nums += [1.0] * ( len(activated_loras) - len(list_mult_choices_nums ) )
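+ # activated_loras now holds lora file names rather than dropdown indices, so match by basename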
+ loras_selected = [ lora for lora in loras if os.path.basename(lora) in activated_loras]
pinnedLora = profile !=5 #False # # #
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
errors = trans._loras_errors
@@ -1620,43 +1630,42 @@ def generate_video(
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
seed = None if seed == -1 else seed
# negative_prompt = "" # not applicable in the inference
-
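+ # whether this is an image-to-video model is now derived from the model filename instead of being passed in as a flag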
+ image2video = test_class_i2v(model_filename)
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
joint_pass = boost ==1 #and profile != 1 and profile != 3
# TeaCache
- trans.enable_teacache = tea_cache > 0
+ trans.enable_teacache = tea_cache_setting > 0
if trans.enable_teacache:
- trans.teacache_multiplier = tea_cache
+ trans.teacache_multiplier = tea_cache_setting
trans.rel_l1_thresh = 0
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if image2video:
- if '480p' in transformer_filename_i2v:
+ if '480p' in model_filename:
# teacache_thresholds = [0.13, .19, 0.26]
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
- elif '720p' in transformer_filename_i2v:
+ elif '720p' in model_filename:
teacache_thresholds = [0.18, 0.2 , 0.3]
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else:
raise gr.Error("Teacache not supported for this model")
else:
- if '1.3B' in transformer_filename_t2v:
+ if '1.3B' in model_filename:
# teacache_thresholds= [0.05, 0.07, 0.08]
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
- elif '14B' in transformer_filename_t2v:
+ elif '14B' in model_filename:
# teacache_thresholds = [0.14, 0.15, 0.2]
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
else:
raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
- resolution_reformated = str(height) + "*" + str(width)
- src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2],
- [image_source3],
- [image_source1],
+ src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
+ [video_mask],
+ [image_refs],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
trim_video=max_frames)
else:
@@ -1721,10 +1730,10 @@ def generate_video(
if image2video:
samples = wan_model.generate(
prompt,
- image_source1,
- image_source2 if image_source2 != None else None,
+ image_start,
+ image_end if image_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
- max_area=MAX_AREA_CONFIGS[resolution],
+ max_area=MAX_AREA_CONFIGS[resolution_reformated],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
@@ -1736,11 +1745,11 @@ def generate_video(
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
- slg_start = slg_start/100,
- slg_end = slg_end/100,
+ slg_start = slg_start_perc/100,
+ slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
- add_frames_for_end_image = not "Fun" in transformer_filename_i2v,
+ add_frames_for_end_image = not "Fun_InP" in model_filename,
)
else:
samples = wan_model.generate(
@@ -1761,8 +1770,8 @@ def generate_video(
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
- slg_start = slg_start/100,
- slg_end = slg_end/100,
+ slg_start = slg_start_perc/100,
+ slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
)
@@ -1863,6 +1872,8 @@ def generate_video(
h, w = sample.shape[-2:]
h *= scale
w *= scale
+ h = int(h)
+ w = int(w)
new_frames =[]
for i in range( sample.shape[1] ):
frame = sample[:, i]
@@ -1881,9 +1892,10 @@ def generate_video(
nrow=1,
normalize=True,
value_range=(-1, 1))
-
- configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
+
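+ # metadata is now collected generically from this function's own arguments (via prepare_inputs_dict) instead of the old hand-written get_settings_dict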
+ inputs = get_function_arguments(generate_video, locals())
+ inputs.pop("progress")
+ configs = prepare_inputs_dict("metadata", inputs)
metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
@@ -1899,8 +1911,6 @@ def generate_video(
file_list.append(video_path)
state['update_gallery'] = True
seed += 1
-
- last_model_type = image2video
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
@@ -1913,36 +1923,6 @@ def prepare_generate_video(state):
return gr.Button(visible= False), gr.Button(visible= True), gr.Column(visible= True)
-def wait_tasks_done(state, progress=gr.Progress()):
-
- gen = get_gen_info(state)
- gen_location = is_gen_location(state)
-
- last_msg = gen.get("last_msg", "")
- if len(last_msg) > 0:
- yield last_msg
-
- if gen_location == None or gen_location:
- return gr.Text()
-
-
- while True:
-
- msg = gen.get("last_msg", "")
- if len(msg) > 0 and last_msg != msg:
- yield msg
- last_msg = msg
- progress_args = gen.get("progress_args", None)
- if progress_args != None:
- progress(*progress_args)
-
- in_progress= gen.get("in_progress", False)
- if not in_progress:
- break
- time.sleep(0.5)
-
-
-
def process_tasks(state, progress=gr.Progress()):
gen = get_gen_info(state)
queue = gen.get("queue", [])
@@ -1950,7 +1930,6 @@ def process_tasks(state, progress=gr.Progress()):
if len(queue) == 0:
return
gen = get_gen_info(state)
- gen["location"] = state["image2video"]
clear_file_list = server_config.get("clear_file_list", 0)
file_list = gen.get("file_list", [])
if clear_file_list > 0:
@@ -2108,7 +2087,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
lset_name_filename = lset_name + ".lset"
- full_lset_name_filename = os.path.join(get_lora_dir(state["image2video"]), lset_name_filename)
+ full_lset_name_filename = os.path.join(get_lora_dir(state["model_filename"]), lset_name_filename)
with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(lset, indent=4))
@@ -2125,7 +2104,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
def delete_lset(state, lset_name):
loras_presets = state["loras_presets"]
- lset_name_filename = os.path.join( get_lora_dir(state["image2video"]), sanitize_file_name(lset_name) + ".lset" )
+ lset_name_filename = os.path.join( get_lora_dir(state["model_filename"]), sanitize_file_name(lset_name) + ".lset" )
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
if not os.path.isfile(lset_name_filename):
raise gr.Error(f"Preset '{lset_name}' not found ")
@@ -2146,8 +2125,8 @@ def delete_lset(state, lset_name):
def refresh_lora_list(state, lset_name, loras_choices):
loras_names = state["loras_names"]
prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
- image2video= state["image2video"]
- loras, loras_names, loras_presets, _, _, _, _ = setup_loras(image2video, None, get_lora_dir(image2video), lora_preselected_preset, None)
+ model_filename= state["model_filename"]
+ loras, loras_names, loras_presets, _, _, _, _ = setup_loras(model_filename, None, get_lora_dir(model_filename), lora_preselected_preset, None)
state["loras"] = loras
state["loras_names"] = loras_names
state["loras_presets"] = loras_presets
@@ -2187,7 +2166,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m
gr.Info("Please choose a preset in the list or create one")
else:
loras = state["loras"]
- loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["image2video"], lset_name, loras)
+ loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["model_filename"], lset_name, loras)
if len(error) > 0:
gr.Info(error)
else:
@@ -2344,71 +2323,110 @@ def switch_advanced(state, new_advanced, lset_name):
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
-def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
-
+def prepare_inputs_dict(target, inputs ):
+
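+ # target selects how much is kept: "state" keeps everything, "settings" also drops the transient image/video inputs, "metadata" further drops keys irrelevant to the current model and any None values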
+ state = inputs.pop("state")
loras = state["loras"]
- activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
+ if "loras_choices" in inputs:
+ loras_choices = inputs.pop("loras_choices")
+ inputs.pop("model_filename", None)
+ activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
+ inputs["activated_loras"] = activated_loras
- ui_settings = {
- "prompts": prompt,
- "resolution": resolution,
- "video_length": video_length,
- "num_inference_steps": num_inference_steps,
- "seed": seed,
- "repeat_generation": repeat_generation,
- "multi_images_gen_type": multi_images_gen_type,
- "guidance_scale": guidance_scale,
- "flow_shift": flow_shift,
- "negative_prompt": negative_prompt,
- "activated_loras": activated_loras,
- "loras_multipliers": loras_mult_choices,
- "tea_cache": tea_cache_setting,
- "tea_cache_start_step_perc": tea_cache_start_step_perc,
- "temporal_upsampling" : temporal_upsampling,
- "spatial_upsampling" : spatial_upsampling,
- "RIFLEx_setting": RIFLEx_setting,
- "slg_switch": slg_switch,
- "slg_layers": slg_layers,
- "slg_start_perc": slg_start_perc,
- "slg_end_perc": slg_end_perc,
- "cfg_star_switch": cfg_star_switch,
- "cfg_zero_step": cfg_zero_step
- }
+ if target == "state":
+ return inputs
- if i2v:
- ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
- ui_settings["image_prompt_type"] = image_prompt_type
- else:
- if "Vace" in transformer_filename_t2v or not image_metadata:
- ui_settings["image_prompt_type"] = image_prompt_type
- ui_settings["max_frames"] = max_frames
- ui_settings["remove_background_image_ref"] = remove_background_image_ref
- ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
+ unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
+ for k in unsaved_params:
+ inputs.pop(k)
- return ui_settings
+ model_filename = state["model_filename"]
+ inputs["type"] = "Wan2.1GP by DeepBeepMeep - " + get_model_name(model_filename)
-def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
+ if target == "settings":
+ return inputs
+
+ if not any(k in model_filename for k in ["image2video", "Fun_InP"]):
+ inputs.pop("image_prompt_type")
- if state.get("validate_success",0) != 1:
- return
- image2video = state["image2video"]
- ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
+ if not "Vace" in model_filename:
+ unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"]
+ for k in unsaved_params:
+ inputs.pop(k)
- defaults_filename = get_settings_file_name(image2video)
+ if target == "metadata":
+ inputs = {k: v for k,v in inputs.items() if v != None }
- with open(defaults_filename, "w", encoding="utf-8") as f:
- json.dump(ui_defaults, f, indent=4)
+ return inputs
- gr.Info("New Default Settings saved")
+def get_function_arguments(func, locals):
+ args_names = list(inspect.signature(func).parameters)
+ kwargs = typing.OrderedDict()
+ for k in args_names:
+ kwargs[k] = locals[k]
+ return kwargs
+
+
+def save_inputs(
+ target,
+ prompt,
+ negative_prompt,
+ resolution,
+ video_length,
+ seed,
+ num_inference_steps,
+ guidance_scale,
+ flow_shift,
+ embedded_guidance_scale,
+ repeat_generation,
+ multi_images_gen_type,
+ tea_cache_setting,
+ tea_cache_start_step_perc,
+ loras_choices,
+ loras_multipliers,
+ image_prompt_type,
+ image_start,
+ image_end,
+ video_prompt_type,
+ image_refs,
+ video_guide,
+ video_mask,
+ max_frames,
+ remove_background_image_ref,
+ temporal_upsampling,
+ spatial_upsampling,
+ RIFLEx_setting,
+ slg_switch,
+ slg_layers,
+ slg_start_perc,
+ slg_end_perc,
+ cfg_star_switch,
+ cfg_zero_step,
+ state,
+):
+
+
+ # if state.get("validate_success",0) != 1:
+ # return
+ model_filename = state["model_filename"]
+ inputs = get_function_arguments(save_inputs, locals())
+ inputs.pop("target")
+ cleaned_inputs = prepare_inputs_dict(target, inputs)
+ if target == "settings":
+ defaults_filename = get_settings_file_name(model_filename)
+
+ with open(defaults_filename, "w", encoding="utf-8") as f:
+ json.dump(cleaned_inputs, f, indent=4)
+
+ gr.Info("New Default Settings saved")
+ elif target == "state":
+ state[get_model_type(model_filename)] = cleaned_inputs
def download_loras():
from huggingface_hub import snapshot_download
yield gr.Row(visible=True), "Please wait while the Loras are being downloaded", *[gr.Column(visible=False)] * 2
- lora_dir = get_lora_dir(True)
+ lora_dir = get_lora_dir(get_model_filename("i2v", quantizeTransformer))
log_path = os.path.join(lora_dir, "log.txt")
if not os.path.isfile(log_path):
import shutil
@@ -2433,42 +2451,130 @@ def download_loras():
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
return
-def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio):
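+# image_prompt_type / video_prompt_type are letter-coded: "S"=start image, "E"=end image, "I"=reference images, "V"=guide video, "M"=video mask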
+def refresh_image_prompt_type(state, image_prompt_type):
if args.multiple_images:
- return gr.Gallery(visible = (image_prompt_type_radio == 1) )
+ return gr.Gallery(visible = "S" in image_prompt_type ), gr.Gallery(visible = "E" in image_prompt_type )
else:
- return gr.Image(visible = (image_prompt_type_radio == 1) )
+ return gr.Image(visible = "S" in image_prompt_type ), gr.Image(visible = "E" in image_prompt_type )
-def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio):
- vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"]
- return gr.Column(visible= vace_model), gr.Radio(value= image_prompt_type_radio), gr.Gallery(visible = "I" in image_prompt_type_radio), gr.Video(visible= "V" in image_prompt_type_radio),gr.Video(visible= "M" in image_prompt_type_radio ), gr.Text(visible= "V" in image_prompt_type_radio) , gr.Checkbox(visible= "I" in image_prompt_type_radio)
+def refresh_video_prompt_type(state, video_prompt_type):
+ return gr.Gallery(visible = "I" in video_prompt_type), gr.Video(visible= "V" in video_prompt_type),gr.Video(visible= "M" in video_prompt_type ), gr.Text(visible= "V" in video_prompt_type) , gr.Checkbox(visible= "I" in video_prompt_type)
-def check_refresh_input_type(state):
- if not state["image2video"]:
- model_file_name = state["image_input_type_model"]
- model_file_needed= model_needed(False)
- if model_file_name != model_file_needed:
- state["image_input_type_model"] = model_file_needed
- return gr.Text(value= str(time.time()))
+
+def handle_celll_selection(state, evt: gr.SelectData):
+ gen = get_gen_info(state)
+ queue = gen.get("queue", [])
+
+ if evt.index is None:
+ return gr.update(), gr.update(), gr.update(visible=False)
+ row_index, col_index = evt.index
+ cell_value = None
+ if col_index in [6, 7, 8]:
+ if col_index == 6: cell_value = "↑"
+ elif col_index == 7: cell_value = "↓"
+ elif col_index == 8: cell_value = "✖"
+ if col_index == 6:
+ new_df_data = move_up(queue, [row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 7:
+ new_df_data = move_down(queue, [row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 8:
+ new_df_data = remove_task(queue, [row_index])
+ gen["prompts_max"] = gen.get("prompts_max",0) - 1
+ update_status(state)
+ return new_df_data, gr.update(), gr.update(visible=False)
+ start_img_col_idx = 4
+ end_img_col_idx = 5
+ image_data_to_show = None
+ if col_index == start_img_col_idx:
+ with lock:
+ row_index += 1
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('start_image_data')
+ elif col_index == end_img_col_idx:
+ with lock:
+ row_index += 1
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('end_image_data')
+
+ if image_data_to_show:
+ return gr.update(), gr.update(value=image_data_to_show[0]), gr.update(visible=True)
+ else:
+ return gr.update(), gr.update(), gr.update(visible=False)
+
+
+def change_model(state, model_choice):
+ model_filename = ""
+ for filename in model_list:
+ if get_model_type(filename) == model_choice:
+ model_filename = filename
+ break
+ if len(model_filename) == 0:
+ return
+
+ state["model_filename"] = model_filename
+ header = generate_header(compile=compile, attention_mode=attention_mode)
+ return header
+
+def fill_inputs(state):
+ model_filename = state["model_filename"]
+ prefix = get_model_type(model_filename)
+ ui_defaults = state.get(prefix, None)
+ if ui_defaults == None:
+ ui_defaults = get_default_settings(model_filename)
+
+ return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
+
+def preload_model(state):
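+ # with reload policy 1 the newly selected model is loaded right away (this runs after the model dropdown changes); other policies defer loading to generate_video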
+ global reload_needed, wan_model, offloadobj
+ if reload_model == 1:
+ model_filename = state["model_filename"]
+ if state["model_filename"] != transformer_filename:
+ wan_model = None
+ if offloadobj is not None:
+ offloadobj.release()
+ offloadobj = None
+ gc.collect()
+ yield f"Loading model {get_model_name(model_filename)}..."
+ wan_model, offloadobj, _ = load_models(model_filename)
+ yield f"Model loaded"
+ reload_needed= False
+ return
return gr.Text()
-def generate_video_tab(image2video=False):
- filename = transformer_filename_i2v if image2video else transformer_filename_t2v
- ui_defaults= get_default_settings(filename, image2video)
+def unload_model_if_needed(state):
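+ # with reload policy 3 the model is fully released and reload_needed is set so the next generation loads it again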
+ global reload_needed, wan_model, offloadobj
+ if reload_model == 3:
+ if wan_model != None:
+ wan_model = None
+ if offloadobj is not None:
+ offloadobj.release()
+ offloadobj = None
+ gc.collect()
+ reload_needed= True
- state_dict = {}
- state_dict["advanced"] = advanced
- state_dict["loras_model"] = filename
- state_dict["image_input_type_model"] = filename
- state_dict["image2video"] = image2video
- gen = dict()
- gen["queue"] = []
- state_dict["gen"] = gen
+def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
+ global inputs_names #, advanced
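+ # when update_form is True the form is being refreshed for a newly selected model: components are re-created from ui_defaults and returned (see the early return further down) instead of wiring a brand new tab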
- preset_to_load = lora_preselected_preset if use_image2video == image2video else ""
+ if update_form:
+ model_filename = state_dict["model_filename"]
+ advanced_ui = state_dict["advanced"]
+ else:
+ model_filename = transformer_filename
+ advanced_ui = advanced
+ ui_defaults= get_default_settings(model_filename)
+ state_dict = {}
+ state_dict["model_filename"] = model_filename
+ state_dict["advanced"] = advanced_ui
+ gen = dict()
+ gen["queue"] = []
+ state_dict["gen"] = gen
- loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
+ preset_to_load = lora_preselected_preset if lora_preset_model == model_filename else ""
+
+ loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(model_filename, None, get_lora_dir(model_filename), preset_to_load, None)
state_dict["loras"] = loras
state_dict["loras_presets"] = loras_presets
@@ -2479,18 +2585,19 @@ def generate_video_tab(image2video=False):
launch_loras = []
launch_multis_str = ""
- if len(default_lora_preset) > 0 and image2video == use_image2video:
+ if update_form:
+ pass
+ if len(default_lora_preset) > 0 and lora_preset_model == model_filename:
launch_preset = default_lora_preset
launch_prompt = default_lora_preset_prompt
launch_loras = default_loras_choices
launch_multis_str = default_loras_multis_str
if len(launch_prompt) == 0:
- launch_prompt = ui_defaults["prompts"]
+ launch_prompt = ui_defaults.get("prompt","")
if len(launch_loras) == 0:
- activated_loras = ui_defaults["activated_loras"]
- launch_multis_str = ui_defaults["loras_multipliers"]
-
+ launch_multis_str = ui_defaults.get("loras_multipliers","")
+ activated_loras = ui_defaults.get("activated_loras",[])
if len(activated_loras) > 0:
lora_filenames = [os.path.basename(lora_path) for lora_path in loras]
activated_indices = []
@@ -2502,24 +2609,20 @@ def generate_video_tab(image2video=False):
print(f"Warning: Lora file {lora_file} from config not found in loras directory")
launch_loras = activated_indices
-
- header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
with gr.Row():
with gr.Column():
with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
with gr.Row(elem_id="image-modal-close-button-row"):
close_modal_button = gr.Button("❌", size="sm")
modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False)
- progress_update_trigger = gr.Textbox(value="0", visible=False, label="_progress_trigger")
- gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger")
- with gr.Row(visible= len(loras)>0) as presets_column:
- lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
+ with gr.Row(visible= True): #len(loras)>0) as presets_column:
+ lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced_ui), "")]
with gr.Column(scale=6):
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=launch_preset)
with gr.Column(scale=1):
with gr.Row(height=17):
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
- refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced)
+ refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced_ui or not only_allow_edit_in_advanced)
save_lset_prompt_drop= gr.Dropdown(
choices=[
("Save Prompt Comments Only", 0),
@@ -2529,62 +2632,50 @@ def generate_video_tab(image2video=False):
with gr.Row(height=17, visible=False) as refresh2_row:
refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1)
- with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows:
+ with gr.Row(height=17, visible=advanced_ui or not only_allow_edit_in_advanced) as preset_buttons_rows:
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
- state = gr.State(state_dict)
- vace_model = "Vace" in filename and not image2video
+ if not update_form:
+ state = gr.State(state_dict)
trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
- with gr.Column(visible= image2video or vace_model) as image_prompt_column:
- if image2video:
- image_source3 = gr.Video(label= "Placeholder", visible= image2video and False)
+ with gr.Column(visible= "image2video" in model_filename or "Fun_InP" in model_filename ) as image_prompt_column:
+ image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
+ image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, scale= 3)
- image_prompt_type= ui_defaults.get("image_prompt_type",0)
- image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3)
-
- if args.multiple_images:
- image_source1 = gr.Gallery(
- label="Images as starting points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True)
- else:
- image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil")
-
- if args.multiple_images:
- image_source2 = gr.Gallery(
- label="Images as ending points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
- else:
- image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
-
-
- image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2])
- max_frames = gr.Slider(1, 100,step=1, visible = False)
- remove_background_image_ref = gr.Text(visible = False)
+ if args.multiple_images:
+ image_start = gr.Gallery(
+ label="Images as starting points for new videos", type ="pil", #file_types= "image",
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
else:
- image_prompt_type= ui_defaults.get("image_prompt_type","I")
- if not isinstance(image_prompt_type, str):
- image_prompt_type ="I"
- image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model)
- image_source1 = gr.Gallery(
- label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type )
+ image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
- image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type )
- with gr.Row():
- max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 )
- remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in image_prompt_type, scale =1 )
+ if args.multiple_images:
+ image_end = gr.Gallery(
+ label="Images as ending points for new videos", type ="pil", #file_types= "image",
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+ else:
+ image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
- image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type )
+ with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
+ video_prompt_type_value= ui_defaults.get("video_prompt_type","I")
+ video_prompt_type = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3)
+ image_refs = gr.Gallery(
+ label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) )
-
- gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref])
+ video_guide = gr.Video(label= "Reference Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) )
+ with gr.Row():
+ max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Ref. Video (0 = as many as possible)", visible= "V" in video_prompt_type_value, scale = 2 )
+ remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in video_prompt_type_value, scale =1 )
+
+ video_mask = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None) )
- advanced_prompt = advanced
+ advanced_prompt = advanced_ui
prompt_vars=[]
if advanced_prompt:
@@ -2615,14 +2706,14 @@ def generate_video_tab(image2video=False):
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
with gr.Row():
- if image2video:
+ if "image2video" in model_filename or "Fun_InP" in model_filename:
resolution = gr.Dropdown(
choices=[
# 720p
("720p", "1280x720"),
("480p", "832x480"),
],
- value=ui_defaults["resolution"],
+ value=ui_defaults.get("resolution","480p"),
label="Resolution (video will have the same height / width ratio than the original image)"
)
else:
@@ -2642,33 +2733,33 @@ def generate_video_tab(image2video=False):
# ("624x832 (3:4, 540p)", "624x832"),
# ("720x720 (1:1, 540p)", "720x720"),
],
- value=ui_defaults["resolution"],
+ value=ui_defaults.get("resolution","832x480"),
label="Resolution"
)
with gr.Row():
with gr.Column():
- video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
+ video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
with gr.Column():
- num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
- show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
- with gr.Row(visible=advanced) as advanced_row:
+ num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
+ show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui)
+ with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Column():
seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)")
with gr.Row():
- repeat_generation = gr.Slider(1, 25.0, value=ui_defaults["repeat_generation"], step=1, label="Default Number of Generated Videos per Prompt")
- multi_images_gen_type = gr.Dropdown( value=ui_defaults["multi_images_gen_type"],
+ repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt")
+ multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0),
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
- guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults["guidance_scale"], step=0.5, label="Guidance Scale", visible=True)
+ guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
- flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults["flow_shift"], step=0.1, label="Shift Scale")
+ flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults["negative_prompt"])
- with gr.Column(visible = len(loras)>0) as loras_column:
+ negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
+ with gr.Column(visible = True): #as loras_column:
gr.Markdown("Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.")
loras_choices = gr.Dropdown(
choices=[
@@ -2678,7 +2769,7 @@ def generate_video_tab(image2video=False):
multiselect= True,
label="Activated Loras"
)
- loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str)
+ loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str)
with gr.Row():
gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
with gr.Row():
@@ -2691,16 +2782,16 @@ def generate_video_tab(image2video=False):
("around x2.25 speed up", 2.25),
("around x2.5 speed up", 2.5),
],
- value=float(ui_defaults["tea_cache"]),
+ value=float(ui_defaults.get("tea_cache_setting",0)),
visible=True,
label="Tea Cache Global Acceleration"
)
- tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
+ tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("tea_cache_start_step_perc",0), step=1, label="Tea Cache starting moment in % of generation")
with gr.Row():
gr.Markdown("Upsampling - postprocessing that may improve fluidity and the size of the video")
with gr.Row():
- temporal_upsampling_choice = gr.Dropdown(
+ temporal_upsampling = gr.Dropdown(
choices=[
("Disabled", ""),
("Rife x2 (32 frames/s)", "rife2"),
@@ -2711,7 +2802,7 @@ def generate_video_tab(image2video=False):
scale = 1,
label="Temporal Upsampling"
)
- spatial_upsampling_choice = gr.Dropdown(
+ spatial_upsampling = gr.Dropdown(
choices=[
("Disabled", ""),
("Lanczos x1.5", "lanczos1.5"),
@@ -2730,7 +2821,7 @@ def generate_video_tab(image2video=False):
("Always ON", 1),
("Always OFF", 2),
],
- value=ui_defaults["RIFLEx_setting"],
+ value=ui_defaults.get("RIFLEx_setting",0),
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
@@ -2741,7 +2832,7 @@ def generate_video_tab(image2video=False):
("OFF", 0),
("ON", 1),
],
- value=ui_defaults["slg_switch"],
+ value=ui_defaults.get("slg_switch",0),
visible=True,
scale = 1,
label="Skip Layer guidance"
@@ -2750,14 +2841,14 @@ def generate_video_tab(image2video=False):
choices=[
(str(i), i ) for i in range(40)
],
- value=ui_defaults["slg_layers"],
+ value=ui_defaults.get("slg_layers", ["9"]),
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
- slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
- slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
+ slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start")
+ slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
@@ -2777,239 +2868,142 @@ def generate_video_tab(image2video=False):
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
+
+ if not update_form:
+ with gr.Column():
+ gen_status = gr.Text(interactive= False, label = "Status")
+ output = gr.Gallery( label="Generated videos", show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
+ generate_btn = gr.Button("Generate")
+ add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
+
+ with gr.Column(visible= False) as current_gen_column:
+ with gr.Row():
+ gen_info = gr.HTML(visible=False, min_height=1)
+ with gr.Row():
+ onemore_btn = gr.Button("One More Sample Please !")
+ abort_btn = gr.Button("Abort")
+
+ queue_df = gr.DataFrame(
+ headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
+ datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
+ column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
+ interactive=False,
+ col_count=(9, "fixed"),
+ wrap=True,
+ value=[],
+ line_breaks= True,
+ visible= False,
+ elem_id="queue_df"
+ )
+
+ extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
+ prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
+ if update_form:
+ locals_dict = locals()
+ gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
+ return gen_inputs
+ else:
+ target_state = gr.Text(value = "state", interactive= False, visible= False)
+ target_settings = gr.Text(value = "settings", interactive= False, visible= False)
+
+ image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
+ video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
- with gr.Column():
- gen_status = gr.Text(interactive= False)
- full_sync = gr.Text(interactive= False, visible= False)
- light_sync = gr.Text(interactive= False, visible= False)
-
- gen_progress_html = gr.HTML(
- label="Status",
- value="Idle",
- elem_id="generation_progress_bar_container", visible= False
+ queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
+ save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+ confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+ save_lset, inputs=[state, lset_name, loras_choices, loras_multipliers, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+ delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+ confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+ cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+ apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt]).then(
+ fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
)
- output = gr.Gallery(
- label="Generated videos", show_label=False, elem_id="gallery"
- , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
- generate_btn = gr.Button("Generate")
- add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
-
- with gr.Column(visible= False) as current_gen_column:
- with gr.Row():
- gen_info = gr.HTML(visible=False, min_height=1)
- with gr.Row():
- onemore_btn = gr.Button("One More Sample Please !")
- abort_btn = gr.Button("Abort")
-
- queue_df = gr.DataFrame(
- headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
- datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
- column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
- interactive=False,
- col_count=(9, "fixed"),
- wrap=True,
- value=[],
- line_breaks= True,
- visible= False,
- # every=1,
- elem_id="queue_df"
- )
- # queue_df = gr.HTML("",
- # visible= False,
- # elem_id="queue_df"
- # )
-
- def handle_selection(state, evt: gr.SelectData):
- gen = get_gen_info(state)
- queue = gen.get("queue", [])
-
- if evt.index is None:
- return gr.update(), gr.update(), gr.update(visible=False)
- row_index, col_index = evt.index
- cell_value = None
- if col_index in [6, 7, 8]:
- if col_index == 6: cell_value = "↑"
- elif col_index == 7: cell_value = "↓"
- elif col_index == 8: cell_value = "✖"
- if col_index == 6:
- new_df_data = move_up(queue, [row_index])
- return new_df_data, gr.update(), gr.update(visible=False)
- elif col_index == 7:
- new_df_data = move_down(queue, [row_index])
- return new_df_data, gr.update(), gr.update(visible=False)
- elif col_index == 8:
- new_df_data = remove_task(queue, [row_index])
- gen["prompts_max"] = gen.get("prompts_max",0) - 1
- update_status(state)
- return new_df_data, gr.update(), gr.update(visible=False)
- start_img_col_idx = 4
- end_img_col_idx = 5
- image_data_to_show = None
- if col_index == start_img_col_idx:
- with lock:
- if row_index < len(queue):
- image_data_to_show = queue[row_index].get('start_image_data')
- elif col_index == end_img_col_idx:
- with lock:
- if row_index < len(queue):
- image_data_to_show = queue[row_index].get('end_image_data')
-
- if image_data_to_show:
- return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
- else:
- return gr.update(), gr.update(), gr.update(visible=False)
- selected_indices = gr.State([])
- queue_df.select(
- fn=handle_selection,
- inputs=state,
- outputs=[queue_df, modal_image_display, modal_container],
- )
- # gallery_update_trigger.change(
- # fn=refresh_gallery_on_trigger,
- # inputs=[state],
- # outputs=[output]
- # )
- # queue_df.change(
- # fn=refresh_gallery,
- # inputs=[state],
- # outputs=[gallery_update_trigger]
- # ).then(
- # fn=refresh_progress,
- # inputs=None,
- # outputs=[progress_update_trigger]
- # )
- progress_update_trigger.change(
- fn=update_generation_status,
- inputs=[progress_update_trigger],
- outputs=[gen_progress_html],
- show_progress="hidden"
- )
- save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
- save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
- loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
- slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
- save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
- confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
- save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
- delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
- confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
- cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
- apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_mult_choices, prompt]).then(
- fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
- )
- refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
- refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
- output.select(select_video, state, None )
-
-
+ refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+ refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+ output.select(select_video, state, None )
+
+ gen_status.change(refresh_gallery,
+ inputs = [state, gen_status],
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
- gen_status.change(refresh_gallery,
- inputs = [state, gen_status],
- outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
-
- full_sync.change(fn= check_refresh_input_type,
- inputs= [state],
- outputs= [trigger_refresh_input_type]
- ).then(fn=refresh_gallery,
- inputs = [state, gen_status],
- outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
- ).then(fn=wait_tasks_done,
- inputs= [state],
- outputs =[gen_status],
- ).then(finalize_generation,
- inputs= [state],
- outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
- )
- light_sync.change(fn= check_refresh_input_type,
- inputs= [state],
- outputs= [trigger_refresh_input_type]
- ).then(fn=refresh_gallery,
- inputs = [state, gen_status],
- outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
- )
+ abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
+ onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
- abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
- onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
+ inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1]
+ locals_dict = locals()
+ gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
+ save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+ save_inputs, inputs =[target_settings] + gen_inputs, outputs = [])
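For reference, the `gen_inputs` list is now built dynamically: the parameter names of `save_inputs` (minus the first and last) are looked up in the enclosing scope via `locals()`, so the Gradio components only have to be declared once. A minimal standalone sketch of that pattern, where the toy `save_inputs` signature and the stand-in values are purely illustrative and not the ones from this patch:

import inspect

def save_inputs(target, prompt, resolution, state):
    # stand-in for the real save_inputs; only the signature matters here
    print(target, prompt, resolution, state)

def build_ui():
    # stand-ins for the gr.Textbox / gr.Dropdown / gr.State components
    prompt = "a cat walking on the grass"
    resolution = "832x480"
    state = {"advanced": False}

    # drop the first ("target") and last ("state") parameters, as the patch does with [1:-1]
    inputs_names = list(inspect.signature(save_inputs).parameters)[1:-1]
    locals_dict = locals()
    gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
    return gen_inputs

print(build_ui())  # ['a cat walking on the grass', '832x480', {'advanced': False}]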
- gen_inputs=[
- prompt,
- negative_prompt,
- resolution,
- video_length,
- seed,
- num_inference_steps,
- guidance_scale,
- flow_shift,
- embedded_guidance_scale,
- repeat_generation,
- multi_images_gen_type,
- tea_cache_setting,
- tea_cache_start_step_perc,
- loras_choices,
- loras_mult_choices,
- image_prompt_type_radio,
- image_source1,
- image_source2,
- image_source3,
- max_frames,
- remove_background_image_ref,
- temporal_upsampling_choice,
- spatial_upsampling_choice,
- RIFLEx_setting,
- slg_switch,
- slg_layers,
- slg_start_perc,
- slg_end_perc,
- cfg_star_switch,
- cfg_zero_step,
- state,
- gr.State(image2video)
- ]
+ model_choice.change(fn=validate_wizard_prompt,
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=save_inputs,
+ inputs =[target_state] + gen_inputs,
+ outputs= None
+ ).then(fn= change_model,
+ inputs=[state, model_choice],
+ outputs= [header]
+ ).then(fn= fill_inputs,
+ inputs=[state],
+ outputs=gen_inputs + extra_inputs
+ ).then(fn= preload_model,
+ inputs=[state],
+ outputs=[gen_status])
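The `model_choice.change` handler above (like the reworked `generate_btn.click` below) relies on Gradio's event chaining: each `.then()` step only starts once the previous one has returned, which is what guarantees the old model's inputs are saved before the new model is loaded. A minimal sketch of the same pattern, assuming a recent Gradio version; the component names and callbacks are illustrative, not the ones from this patch:

import gradio as gr

def save_old_inputs(choice, state):
    state["saved_for"] = choice          # persist the settings of the model being left
    return state

def switch_model(choice):
    return f"Now using model: {choice}"  # stand-in for change_model / generate_header

with gr.Blocks() as demo:
    state = gr.State({})
    model = gr.Dropdown(["1.3B", "14B"], value="1.3B", label="Model")
    header = gr.Markdown("Now using model: 1.3B")
    model.change(fn=save_old_inputs, inputs=[model, state], outputs=[state]
        ).then(fn=switch_model, inputs=[model], outputs=[header])

# demo.launch()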
- generate_btn.click(fn=validate_wizard_prompt,
- inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
- ).then(fn=process_prompt_and_add_tasks,
- inputs = gen_inputs,
- outputs= queue_df
- ).then(fn=prepare_generate_video,
- inputs= [state],
- outputs= [generate_btn, add_to_queue_btn, current_gen_column],
- ).then(fn=process_tasks,
- inputs= [state],
- outputs= [gen_status],
- ).then(finalize_generation,
- inputs= [state],
- outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
- )
+ generate_btn.click(fn=validate_wizard_prompt,
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=save_inputs,
+ inputs =[target_state] + gen_inputs,
+ outputs= None
+ ).then(fn=process_prompt_and_add_tasks,
+ inputs = [state, model_choice],
+ outputs= queue_df
+ ).then(fn=prepare_generate_video,
+ inputs= [state],
+ outputs= [generate_btn, add_to_queue_btn, current_gen_column],
+ ).then(fn=process_tasks,
+ inputs= [state],
+ outputs= [gen_status],
+ ).then(finalize_generation,
+ inputs= [state],
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
+ ).then(unload_model_if_needed,
+ inputs= [state],
+ outputs= []
+ )
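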
- add_to_queue_btn.click(fn=validate_wizard_prompt,
- inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
- ).then(
- fn=process_prompt_and_add_tasks,
- inputs = gen_inputs,
- outputs=queue_df
- ).then(
- fn=update_status,
- inputs = [state],
- )
+ add_to_queue_btn.click(fn=validate_wizard_prompt,
+ inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=save_inputs,
+ inputs =[target_state] + gen_inputs,
+ outputs= None
+ ).then(fn=process_prompt_and_add_tasks,
+ inputs = [state, model_choice],
+ outputs=queue_df
+ ).then(
+ fn=update_status,
+ inputs = [state],
+ )
+ close_modal_button.click(
+ lambda: gr.update(visible=False),
+ inputs=[],
+ outputs=[modal_container]
+ )
- close_modal_button.click(
- lambda: gr.update(visible=False),
- inputs=[],
- outputs=[modal_container]
- )
- return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
+ return loras_choices, lset_name, state
-def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state):
+def generate_download_tab(lset_name,loras_choices, state):
with gr.Row():
with gr.Row(scale =2):
- gr.Markdown("Wan2GP's Lora Festival ! Press the following button to download i2v Remade Loras collection (and bonuses Loras).")
+ gr.Markdown("WanGP's Lora Festival ! Press the following button to download i2v Remade_AI Loras collection (and bonuses Loras).")
with gr.Row(scale =1):
download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
with gr.Row(scale =1):
@@ -3017,7 +3011,7 @@ def generate_download_tab(presets_column, loras_column, lset_name,loras_choices,
with gr.Row() as download_status_row:
download_status = gr.Markdown()
- download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+ download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
def generate_configuration_tab():
@@ -3025,37 +3019,50 @@ def generate_configuration_tab():
state = gr.State(state_dict)
gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
with gr.Column():
- index = transformer_choices_t2v.index(transformer_filename_t2v)
+ index = transformer_choices.index(transformer_filename)
index = 0 if index ==0 else index
- transformer_t2v_choice = gr.Dropdown(
+
+ model_list = []
+ for model_type in model_types:
+ choice = get_model_filename(model_type, transformer_quantization)
+ model_list.append(choice)
+ dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
+ transformer_type_choice = gr.Dropdown(
+ choices= dropdown_choices,
+ value= get_model_type(transformer_filename),
+ label= "Default Wan Transformer Model",
+ scale= 2
+ )
+
+ # transformer_choice = gr.Dropdown(
+ # choices=[
+ # ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
+ # ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
+ # ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
+ # ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
+ # ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4),
+ # ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5),
+ # ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6),
+ # ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7),
+ # ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8),
+ # ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9),
+ # ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10),
+ # ],
+ # value= index,
+ # label="Transformer model for Image to Video",
+ # interactive= not lock_ui_transformer,
+ # visible = True,
+ # )
+
+ quantization_choice = gr.Dropdown(
choices=[
- ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
- ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
- ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
- ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
+ ("Int8 Quantization (recommended)", "int8"),
+ ("BF16 (no quantization)", "bf16"),
],
- value= index,
- label="Transformer model for Text to Video",
- interactive= not lock_ui_transformer,
- visible=True
- )
- index = transformer_choices_i2v.index(transformer_filename_i2v)
- index = 0 if index ==0 else index
- transformer_i2v_choice = gr.Dropdown(
- choices=[
- ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
- ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
- ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2),
- ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 3),
- ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 4),
- ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 5),
- ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 6),
- ],
- value= index,
- label="Transformer model for Image to Video",
- interactive= not lock_ui_transformer,
- visible = True,
- )
+ value= transformer_quantization,
+ label="Wan Transformer Model Quantization (if available)",
+ )
+
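The loop above that turns `model_types` into `(display name, model type)` dropdown entries is repeated verbatim in the main Video Generator tab further down; a small helper would keep the two lists in sync. A possible sketch that reuses the helpers already referenced by this patch (the helper name itself is illustrative, not part of the change):

def build_model_dropdown_choices(quantization):
    # one (label, value) pair per model type, honouring the selected quantization
    filenames = [get_model_filename(model_type, quantization) for model_type in model_types]
    return [(get_model_name(filename), get_model_type(filename)) for filename in filenames]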
index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
@@ -3149,11 +3156,12 @@ def generate_configuration_tab():
)
reload_choice = gr.Dropdown(
choices=[
- ("When changing tabs", 1),
- ("When pressing Generate", 2),
+ ("Load Model When Changing Model", 1),
+ ("Load Model When Pressing Generate", 2),
+ ("Load Model When Pressing Generate and Unload Model when Finished", 3),
],
value=server_config.get("reload_model",2),
- label="Reload model"
+ label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)
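The three values of the reload policy above correspond to the `preload_model` / `unload_model_if_needed` steps wired into the event chains earlier in this patch. A hedged sketch of how those values could be interpreted; the real logic may differ:

LOAD_ON_MODEL_CHANGE = 1                   # load as soon as another model is selected
LOAD_ON_GENERATE = 2                       # load lazily, when Generate is pressed
LOAD_ON_GENERATE_UNLOAD_WHEN_DONE = 3      # same as 2, but release RAM once the queue finishes

def should_preload_on_model_change(policy: int) -> bool:
    return policy == LOAD_ON_MODEL_CHANGE

def should_unload_after_queue(policy: int) -> bool:
    return policy == LOAD_ON_GENERATE_UNLOAD_WHEN_DONE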
clear_file_list_choice = gr.Dropdown(
@@ -3176,8 +3184,7 @@ def generate_configuration_tab():
fn=apply_changes,
inputs=[
state,
- transformer_t2v_choice,
- transformer_i2v_choice,
+ transformer_type_choice,
text_encoder_choice,
save_path_choice,
attention_choice,
@@ -3185,7 +3192,7 @@ def generate_configuration_tab():
profile_choice,
vae_config_choice,
metadata_choice,
- default_ui_choice,
+ quantization_choice,
boost_choice,
clear_file_list_choice,
reload_choice,
@@ -3194,7 +3201,7 @@ def generate_configuration_tab():
)
def generate_about_tab():
- gr.Markdown("Wan2.1GP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)
")
+ gr.Markdown("WanGP - Wan 2.1 model for the GPU Poor by DeepBeepMeep (GitHub)
")
gr.Markdown("Original Wan 2.1 Model by Alibaba (GitHub)")
gr.Markdown("Many thanks to:")
gr.Markdown("- Alibaba Wan team for the best open source video generator")
@@ -3202,114 +3209,19 @@ def generate_about_tab():
gr.Markdown("- Tophness : created multi tabs and queuing frameworks")
gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance")
gr.Markdown("- Remade_AI : for creating their awesome Loras collection")
+
+def generate_info_tab():
+ gr.Markdown("Welcome to WanGP a super fast and low VRAM AI Video Generator !")
-
-def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
- t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
- i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
-
- new_t2v = evt.index == 0
- new_i2v = evt.index == 1
- i2v_light_sync = gr.Text()
- t2v_light_sync = gr.Text()
- i2v_full_sync = gr.Text()
- t2v_full_sync = gr.Text()
-
- last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
- if last_tab_was_image2video == None or last_tab_was_image2video:
- gen = i2v_state["gen"]
- t2v_state["gen"] = gen
- else:
- gen = t2v_state["gen"]
- i2v_state["gen"] = gen
+ gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
+ gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
+ gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
+ gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
+ gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
+ gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
- if new_t2v or new_i2v:
- if last_tab_was_image2video != None and new_t2v != new_i2v:
- gen_location = gen.get("location", None)
- if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
- if new_i2v:
- i2v_full_sync = gr.Text(str(time.time()))
- else:
- t2v_full_sync = gr.Text(str(time.time()))
- else:
- if new_i2v:
- i2v_light_sync = gr.Text(str(time.time()))
- else:
- t2v_light_sync = gr.Text(str(time.time()))
- global_state["last_tab_was_image2video"] = new_i2v
-
- if(server_config.get("reload_model",2) == 1):
- queue = gen.get("queue", [])
-
- queue_empty = len(queue) == 0
- if queue_empty:
- global wan_model, offloadobj
- if wan_model is not None:
- if offloadobj is not None:
- offloadobj.release()
- offloadobj = None
- wan_model = None
- gc.collect()
- torch.cuda.empty_cache()
- wan_model, offloadobj, trans = load_models(new_i2v)
- del trans
-
- if new_t2v or new_i2v:
- state = i2v_state if new_i2v else t2v_state
- lora_model_filename = state["loras_model"]
- model_filename = model_needed(new_i2v)
- if ("1.3B" in model_filename and not "1.3B" in lora_model_filename or "14B" in model_filename and not "14B" in lora_model_filename):
- lora_dir = get_lora_dir(new_i2v)
- loras, loras_names, loras_presets, _, _, _, _ = setup_loras(new_i2v, None, lora_dir, lora_preselected_preset, None)
- state["loras"] = loras
- state["loras_names"] = loras_names
- state["loras_presets"] = loras_presets
- state["loras_model"] = model_filename
-
- advanced = state["advanced"]
- new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)]
- lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(advanced), "")]
- visible = len(loras_names)>0
- if new_t2v:
- return [
- gr.Column(visible= visible),
- gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
- gr.Column(visible= visible),
- gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
- t2v_header,
- t2v_light_sync,
- t2v_full_sync,
- gr.Column(),
- gr.Dropdown(),
- gr.Column(),
- gr.Dropdown(),
- gr.Markdown(),
- gr.Text(),
- gr.Text(),
- ]
- else:
- return [
- gr.Column(),
- gr.Dropdown(),
- gr.Column(),
- gr.Dropdown(),
- gr.Markdown(),
- gr.Text(),
- gr.Text(),
- gr.Text(),
- gr.Column(visible= visible),
- gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
- gr.Column(visible= visible),
- gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
- i2v_header,
- i2v_light_sync,
- i2v_full_sync,
- ]
-
- return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, t2v_light_sync, t2v_full_sync,
- gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header, i2v_light_sync, i2v_full_sync]
def create_demo():
@@ -3539,40 +3451,41 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
- gr.Markdown("Wan 2.1GP v4.0 by DeepBeepMeep (Updates)
")
- gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !")
-
- with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
- gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
- gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
- gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
- gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
- gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
- gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
- global_dict = {}
- global_dict["last_tab_was_image2video"] = use_image2video
- global_state = gr.State(global_dict)
+ gr.Markdown("WanGP v4.0 by DeepBeepMeep ") # (Updates)
")
+ global model_list
- with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
- with gr.Tab("Text To Video", id="t2v") as t2v_tab:
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False)
- with gr.Tab("Image To Video", id="i2v") as i2v_tab:
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
+ with gr.Tabs(selected="video_gen", ) as main_tabs:
+ with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
+ with gr.Row():
+ header = gr.Markdown(generate_header(compile, attention_mode), visible= True)
+ with gr.Row():
+ gr.Markdown("")
+
+ model_list = []
+ for model_type in model_types:
+ choice = get_model_filename(model_type, transformer_quantization)
+ model_list.append(choice)
+ dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
+ model_choice = gr.Dropdown(
+ choices= dropdown_choices,
+ value= get_model_type(transformer_filename),
+ show_label= False,
+ scale= 2
+ )
+ gr.Markdown("")
+ with gr.Row():
+
+ loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
+ with gr.Tab("Informations"):
+ generate_info_tab()
if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab:
- generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
+ generate_download_tab(lset_name, loras_choices, state)
with gr.Tab("Configuration"):
generate_configuration_tab()
with gr.Tab("About"):
generate_about_tab()
- main_tabs.select(
- fn=on_tab_select,
- inputs=[global_state, t2v_state, i2v_state],
- outputs=[
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync,
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync
- ]
- )
+
return demo
if __name__ == "__main__":