diff --git a/gradio_server.py b/gradio_server.py
index 6f7f691..e21c5cb 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -1,5 +1,6 @@
import os
import time
+import sys
import threading
import argparse
from mmgp import offload, safetensors2, profile_type
@@ -33,15 +34,12 @@ mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
exit()
-queue = []
lock = threading.Lock()
current_task_id = None
task_id = 0
-progress_tracker = {}
-tracker_lock = threading.Lock()
-file_list = []
+# progress_tracker = {}
+# tracker_lock = threading.Lock()
last_model_type = None
-last_status_string = ""
def format_time(seconds):
if seconds < 60:
@@ -77,37 +75,6 @@ def pil_to_base64_uri(pil_image, format="png", quality=75):
print(f"Error converting PIL to base64: {e}")
return None
-def runner():
- global current_task_id
- while True:
- with lock:
- for item in queue:
- task_id_runner = item['id']
- with tracker_lock:
- progress = progress_tracker.get(task_id_runner, {})
-
- if item['state'] == "Processing":
- current_step = progress.get('current_step', 0)
- total_steps = progress.get('total_steps', 0)
- elapsed = time.time() - progress.get('start_time', time.time())
- status = progress.get('status', "")
- repeats = progress.get("repeats", "0/0")
- item.update({
- 'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
- 'steps': f"{current_step}/{total_steps}",
- 'time': format_time(elapsed),
- 'repeats': f"{repeats}",
- 'status': f"{status}"
- })
- if not any(item['state'] == "Processing" for item in queue):
- for item in queue:
- if item['state'] == "Queued":
- item['status'] = "Processing"
- item['state'] = "Processing"
- current_task_id = item['id']
- threading.Thread(target=process_task, args=(item,)).start()
- break
- time.sleep(1)
def process_prompt_and_add_tasks(
prompt,
@@ -137,161 +104,290 @@ def process_prompt_and_add_tasks(
slg_end,
cfg_star_switch,
cfg_zero_step,
- state_arg,
+ state,
image2video
):
-
- if state_arg.get("validate_success",0) != 1:
- print("Validation failed, not adding tasks.")
+
+ if state.get("validate_success",0) != 1:
+ gr.Info("Validation failed, not adding tasks.")
return
+
+ state["validate_success"] = 0
if len(prompt) ==0:
return
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
- print("Error processing prompt template: " + errors)
+ gr.Info("Error processing prompt template: " + errors)
return
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
if len(prompts) ==0:
return
- for single_prompt in prompts:
- task_params = (
- single_prompt,
- negative_prompt,
- resolution,
- video_length,
- seed,
- num_inference_steps,
- guidance_scale,
- flow_shift,
- embedded_guidance_scale,
- repeat_generation,
- multi_images_gen_type,
- tea_cache,
- tea_cache_start_step_perc,
- loras_choices,
- loras_mult_choices,
- image_prompt_type,
- image_to_continue,
- image_to_end,
- video_to_continue,
- max_frames,
- RIFLEx_setting,
- slg_switch,
- slg_layers,
- slg_start,
- slg_end,
- cfg_star_switch,
- cfg_zero_step,
- state_arg,
- image2video
- )
- add_video_task(*task_params)
- return update_queue_data()
+ file_model_needed = model_needed(image2video)
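+    # resolution compatibility is now checked here, at queue time, instead of inside generate_video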
+ if image2video:
+ width, height = resolution.split("x")
+ width, height = int(width), int(height)
-def process_task(task):
- try:
- task_id, *params = task['params']
- generate_video(task_id, *params)
- finally:
- with lock:
- queue[:] = [item for item in queue if item['id'] != task['id']]
- with tracker_lock:
- if task['id'] in progress_tracker:
- del progress_tracker[task['id']]
+ if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
+ gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
+ return
+ resolution = str(width) + "*" + str(height)
+ if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
+ gr.Info(f"Resolution {resolution} not supported by image 2 video")
+ return
-def add_video_task(*params):
+ if "1.3B" in file_model_needed and width * height > 848*480:
+ gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
+ return
+
+ if image2video:
+ if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0:
+ return
+ if image_prompt_type == 0:
+ image_to_end = None
+ if isinstance(image_to_continue, list):
+ image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ]
+ else:
+ image_to_continue = [convert_image(image_to_continue)]
+ if image_to_end != None:
+ if isinstance(image_to_end , list):
+ image_to_end = [ convert_image(tup[0]) for tup in image_to_end ]
+ else:
+ image_to_end = [convert_image(image_to_end) ]
+ if len(image_to_continue) != len(image_to_end):
+ gr.Info("The number of start and end images should be the same ")
+ return
+
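+            # multi_images_gen_type == 0 builds the cartesian product of prompts and start images;
+            # any other value matches prompts and images positionally, repeating the shorter list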
+ if multi_images_gen_type == 0:
+ new_prompts = []
+ new_image_to_continue = []
+ new_image_to_end = []
+ for i in range(len(prompts) * len(image_to_continue) ):
+ new_prompts.append( prompts[ i % len(prompts)] )
+ new_image_to_continue.append(image_to_continue[i // len(prompts)] )
+ if image_to_end != None:
+ new_image_to_end.append(image_to_end[i // len(prompts)] )
+ prompts = new_prompts
+ image_to_continue = new_image_to_continue
+ if image_to_end != None:
+ image_to_end = new_image_to_end
+ else:
+ if len(prompts) >= len(image_to_continue):
+ if len(prompts) % len(image_to_continue) !=0:
+ raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
+ rep = len(prompts) // len(image_to_continue)
+ new_image_to_continue = []
+ new_image_to_end = []
+ for i, _ in enumerate(prompts):
+ new_image_to_continue.append(image_to_continue[i//rep] )
+ if image_to_end != None:
+ new_image_to_end.append(image_to_end[i//rep] )
+ image_to_continue = new_image_to_continue
+ if image_to_end != None:
+ image_to_end = new_image_to_end
+ else:
+ if len(image_to_continue) % len(prompts) !=0:
+ raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
+ rep = len(image_to_continue) // len(prompts)
+ new_prompts = []
+ for i, _ in enumerate(image_to_continue):
+ new_prompts.append( prompts[ i//rep] )
+ prompts = new_prompts
+
+ # elif video_to_continue != None and len(video_to_continue) >0 :
+ # input_image_or_video_path = video_to_continue
+ # # pipeline.num_input_frames = max_frames
+ # # pipeline.max_frames = max_frames
+ # else:
+ # return
+ # else:
+ # input_image_or_video_path = None
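+    # pad with None so the zip() below always yields one (prompt, start_image, end_image) triple per task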
+ if image_to_continue == None:
+ image_to_continue = [None] * len(prompts)
+ if image_to_end == None:
+ image_to_end = [None] * len(prompts)
+
+ for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
+ kwargs = {
+ "prompt" : single_prompt,
+ "negative_prompt" : negative_prompt,
+ "resolution" : resolution,
+ "video_length" : video_length,
+ "seed" : seed,
+ "num_inference_steps" : num_inference_steps,
+ "guidance_scale" : guidance_scale,
+ "flow_shift" : flow_shift,
+ "embedded_guidance_scale" : embedded_guidance_scale,
+ "repeat_generation" : repeat_generation,
+ "multi_images_gen_type" : multi_images_gen_type,
+ "tea_cache" : tea_cache,
+ "tea_cache_start_step_perc" : tea_cache_start_step_perc,
+ "loras_choices" : loras_choices,
+ "loras_mult_choices" : loras_mult_choices,
+ "image_prompt_type" : image_prompt_type,
+ "image_to_continue": image_start,
+ "image_to_end" : image_end,
+ "video_to_continue" : video_to_continue ,
+ "max_frames" : max_frames,
+ "RIFLEx_setting" : RIFLEx_setting,
+ "slg_switch" : slg_switch,
+ "slg_layers" : slg_layers,
+ "slg_start" : slg_start,
+ "slg_end" : slg_end,
+ "cfg_star_switch" : cfg_star_switch,
+ "cfg_zero_step" : cfg_zero_step,
+ "state" : state,
+ "image2video" : image2video
+ }
+ add_video_task(**kwargs)
+
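+    # accumulate the number of pending prompts so the status line can display "Prompt x/y"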
+ gen = get_gen_info(state)
+ gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0)
+ state["validate_success"] = 1
+ queue= gen.get("queue", [])
+ return update_queue_data(queue)
+
+
+
+
+def add_video_task(**kwargs):
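+    # the task queue now lives in the per-session state dict (gen["queue"]) instead of a module-level list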
global task_id
- with lock:
- task_id += 1
- current_task_id = task_id
- start_image_data = params[16] if len(params) > 16 else None
- end_image_data = params[17] if len(params) > 17 else None
+ state = kwargs["state"]
+ gen = get_gen_info(state)
+ queue = gen["queue"]
+ task_id += 1
+ current_task_id = task_id
+ start_image_data = kwargs["image_to_continue"]
+ end_image_data = kwargs["image_to_end"]
- queue.append({
- "id": current_task_id,
- "params": (current_task_id,) + params,
- "state": "Queued",
- "status": "Queued",
- "repeats": "0/0",
- "progress": "0.0%",
- "steps": f"0/{params[5]}",
- "time": "--",
- "prompt": params[0],
- "start_image_data": start_image_data,
- "end_image_data": end_image_data
- })
- return update_queue_data()
+ queue.append({
+ "id": current_task_id,
+ "image2video": kwargs["image2video"],
+ "params": kwargs.copy(),
+ "repeats": kwargs["repeat_generation"],
+ "length": kwargs["video_length"],
+ "steps": kwargs["num_inference_steps"],
+ "prompt": kwargs["prompt"],
+ "start_image_data": start_image_data,
+ "end_image_data": end_image_data,
+ "start_image_data_base64": pil_to_base64_uri(start_image_data, format="jpeg", quality=70),
+ "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
+ })
+ return update_queue_data(queue)
-def move_up(selected_indices):
+def move_up(queue, selected_indices):
if not selected_indices or len(selected_indices) == 0:
- return update_queue_data()
+ return update_queue_data(queue)
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
if idx > 0:
+ idx += 1
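+            # displayed row 0 corresponds to queue[1]: queue[0] holds the task currently being processed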
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
- return update_queue_data()
+ return update_queue_data(queue)
-def move_down(selected_indices):
+def move_down(queue, selected_indices):
if not selected_indices or len(selected_indices) == 0:
- return update_queue_data()
+ return update_queue_data(queue)
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
idx = int(idx)
with lock:
+ idx += 1
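+        # shift by one because queue[0] (the running task) is not shown in the table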
if idx < len(queue)-1:
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
- return update_queue_data()
+ return update_queue_data(queue)
-def remove_task(selected_indices):
+def remove_task(queue, selected_indices):
if not selected_indices or len(selected_indices) == 0:
- return update_queue_data()
+ return update_queue_data(queue)
idx = selected_indices[0]
if isinstance(idx, list):
idx = idx[0]
- idx = int(idx)
+ idx = int(idx) + 1
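+    # offset by one for the same reason as move_up/move_down: queue[0] is the in-progress task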
with lock:
if idx < len(queue):
if idx == 0:
wan_model._interrupt = True
del queue[idx]
- return update_queue_data()
+ return update_queue_data(queue)
-def update_queue_data():
- with lock:
- data = []
- for item in queue:
- truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
-            full_prompt = item['prompt'].replace('"', '&quot;')
-            prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
- start_img_uri = pil_to_base64_uri(item.get('start_image_data'), format="jpeg", quality=70)
- end_img_uri = pil_to_base64_uri(item.get('end_image_data'), format="jpeg", quality=70)
- thumbnail_size = "50px"
- start_img_md = ""
- end_img_md = ""
- if start_img_uri:
-                start_img_md = f'<img src="{start_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
- if end_img_uri:
-                end_img_md = f'<img src="{end_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
- data.append([
- item.get('status', "Starting"),
- item.get('repeats', "0/0"),
- item.get('progress', "0.0%"),
- item.get('steps', ''),
- item.get('time', '--'),
- prompt_cell,
- start_img_md,
- end_img_md,
- "↑",
- "↓",
- "✖"
- ])
- return data
+
+
+def get_queue_table(queue):
+ data = []
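+    # a queue of length 1 only contains the task currently being generated, which is shown elsewhere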
+ if len(queue) == 1:
+ return data
+
+ # def td(l, content, width =None):
+ # if width !=None:
+ # l.append("
" + content + " | ")
+ # else:
+ # l.append("" + content + " | ")
+
+ # data.append("| Qty | Prompt | Steps | | | | | |
")
+
+ for i, item in enumerate(queue):
+ if i==0:
+ continue
+ truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
+        full_prompt = item['prompt'].replace('"', '&quot;')
+        prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
+ start_img_uri =item.get('start_image_data_base64')
+ end_img_uri = item.get('end_image_data_base64')
+ thumbnail_size = "50px"
+ num_steps = item.get('steps')
+ length = item.get('length')
+ start_img_md = ""
+ end_img_md = ""
+ if start_img_uri:
+            start_img_md = f'<img src="{start_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
+ if end_img_uri:
+            end_img_md = f'<img src="{end_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
+ # if i % 2 == 1:
+ # data.append("")
+ # else:
+ # data.append("
")
+
+ # td(data,str(item.get('repeats', "1")) )
+ # td(data, prompt_cell, "100%")
+ # td(data, num_steps, "100%")
+ # td(data, start_img_md)
+ # td(data, end_img_md)
+ # td(data, "↑")
+ # td(data, "↓")
+ # td(data, "✖")
+ # data.append("
")
+ # data.append("
")
+ # return ''.join(data)
+
+ data.append([item.get('repeats', "1"),
+ prompt_cell,
+ length,
+ num_steps,
+ start_img_md,
+ end_img_md,
+ "↑",
+ "↓",
+ "✖"
+ ])
+ return data
+def update_queue_data(queue):
+
+ data = get_queue_table(queue)
+
+ # if len(data) == 0:
+ # return gr.HTML(visible=False)
+ # else:
+ # return gr.HTML(value=data, visible= True)
+ if len(data) == 0:
+ return gr.DataFrame(visible=False)
+ else:
+ return gr.DataFrame(value=data, visible= True)
def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
@@ -306,35 +402,35 @@ def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
"""
return html
-def refresh_progress():
- global current_task_id, progress_tracker, last_status_string
- task_id_to_check = current_task_id
- is_idle = True
- status_string = "Starting..."
- progress_percent = 0.0
- html_content = ""
+# def refresh_progress():
+# global current_task_id, progress_tracker, last_status_string
+# task_id_to_check = current_task_id
+# is_idle = True
+# status_string = "Starting..."
+# progress_percent = 0.0
+# html_content = ""
- with tracker_lock:
- with lock:
- processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue)
- if task_id_to_check is not None:
- progress_data = progress_tracker.get(task_id_to_check)
- if progress_data:
- is_idle = False
- current_step = progress_data.get('current_step', 0)
- total_steps = progress_data.get('total_steps', 0)
- status = progress_data.get('status', "Starting...")
- repeats = progress_data.get("repeats", "0/0")
+# with tracker_lock:
+# with lock:
+# processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue)
+# if task_id_to_check is not None:
+# progress_data = progress_tracker.get(task_id_to_check)
+# if progress_data:
+# is_idle = False
+# current_step = progress_data.get('current_step', 0)
+# total_steps = progress_data.get('total_steps', 0)
+# status = progress_data.get('status', "Starting...")
+# repeats = progress_data.get("repeats", 1)
- if total_steps > 0:
- progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps)))
- progress_percent = progress_float * 100
- status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)"
- else:
- progress_percent = 0.0
- status_string = f"{status} [{repeats}] - Initializing..."
- html_content = create_html_progress_bar(progress_percent, status_string, is_idle)
- return gr.update(value=html_content)
+# if total_steps > 0:
+# progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps)))
+# progress_percent = progress_float * 100
+# status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)"
+# else:
+# progress_percent = 0.0
+# status_string = f"{status} [{repeats}] - Initializing..."
+# html_content = create_html_progress_bar(progress_percent, status_string, is_idle)
+# return gr.update(value=html_content)
def update_generation_status(html_content):
if(html_content):
@@ -736,7 +832,6 @@ if args.i2v_1_3B:
only_allow_edit_in_advanced = False
lora_preselected_preset = args.lora_preset
-lora_preselected_preset_for_i2v = use_image2video
# if args.fast : #or args.fastest
# transformer_filename_t2v = transformer_choices_t2v[2]
# attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
@@ -749,7 +844,6 @@ if args.compile: #args.fastest or
lock_ui_compile = True
model_filename = ""
-lora_model_filename = ""
#attention_mode="sage"
#attention_mode="sage2"
#attention_mode="flash"
@@ -758,15 +852,12 @@ lora_model_filename = ""
# compile = "transformer"
def preprocess_loras(sd):
- if not use_image2video:
- return sd
-
- new_sd = {}
first = next(iter(sd), None)
if first == None:
return sd
- if not first.startswith("lora_unet_"):
+ if not first.startswith("lora_unet_"):
return sd
+ new_sd = {}
print("Converting Lora Safetensors format to Lora Diffusers format")
alphas = {}
repl_list = ["cross_attn", "self_attn", "ffn"]
@@ -845,14 +936,14 @@ download_models(transformer_filename_i2v if use_image2video else transformer_fil
def sanitize_file_name(file_name, rep =""):
return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
-def extract_preset(lset_name, loras):
+def extract_preset(image2video, lset_name, loras):
loras_choices = []
loras_choices_files = []
loras_mult_choices = ""
prompt =""
full_prompt =""
lset_name = sanitize_file_name(lset_name)
- lora_dir = get_lora_dir(use_image2video)
+ lora_dir = get_lora_dir(image2video)
if not lset_name.endswith(".lset"):
lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
else:
@@ -923,7 +1014,7 @@ def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_line
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
raise Exception(f"Unknown preset '{lora_preselected_preset}'")
default_lora_preset = lora_preselected_preset
- default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(default_lora_preset, loras)
+ default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(i2v, default_lora_preset, loras)
if len(error) > 0:
print(error[:200])
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
@@ -1010,7 +1101,7 @@ def load_models(i2v):
# kwargs["partialPinning"] = True
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
- offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+ offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
@@ -1087,7 +1178,7 @@ def apply_changes( state,
if gen_in_progress:
yield "Unable to change config when a generation is in progress
"
return
- global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
+ global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
server_config = {"attention_mode" : attention_choice,
"transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
"transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
@@ -1152,44 +1243,152 @@ def save_video(final_frames, output_path, fps=24):
final_frames = (final_frames * 255).astype(np.uint8)
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
-def build_callback(taskid, state, pipe, num_inference_steps, repeats):
- start_time = time.time()
- def update_progress(step_idx, _):
- with tracker_lock:
- step_idx += 1
- if state.get("abort", False):
- # pipe._interrupt = True
- phase = "Aborting"
- elif step_idx == num_inference_steps:
- phase = "VAE Decoding"
- else:
- phase = "Denoising"
- elapsed = time.time() - start_time
- progress_tracker[taskid] = {
- 'current_step': step_idx,
- 'total_steps': num_inference_steps,
- 'start_time': start_time,
- 'last_update': time.time(),
- 'repeats': repeats,
- 'status': phase
- }
- return update_progress
-def refresh_gallery(state):
- return gr.update(value=state.get("file_list", []))
+def get_gen_info(state):
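+    # lazily create the per-session "gen" dict that holds the queue, progress info and gallery file list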
+ cache = state.get("gen", None)
+ if cache == None:
+ cache = dict()
+ state["gen"] = cache
+ return cache
+
+def build_callback(state, pipe, progress, status, num_inference_steps):
+ def callback(step_idx, force_refresh, read_state = False):
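+        # called during generation after each denoising step; with read_state=True it re-emits the
+        # stored progress phase without advancing the step counter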
+ gen = get_gen_info(state)
+ refresh_id = gen.get("refresh", -1)
+ if force_refresh or step_idx >= 0:
+ pass
+ else:
+ refresh_id = gen.get("refresh", -1)
+ if refresh_id < 0:
+ return
+ UI_refresh = state.get("refresh", 0)
+ if UI_refresh >= refresh_id:
+ return
+
+ status = gen["progress_status"]
+ state["refresh"] = refresh_id
+ if read_state:
+ phase, step_idx = gen["progress_phase"]
+ else:
+ step_idx += 1
+ if gen.get("abort", False):
+ # pipe._interrupt = True
+ phase = " - Aborting"
+ elif step_idx == num_inference_steps:
+ phase = " - VAE Decoding"
+ else:
+ phase = " - Denoising"
+ gen["progress_phase"] = (phase, step_idx)
+ status_msg = status + phase
+ if step_idx >= 0:
+ progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
+ else:
+ progress_args = [0, status_msg]
+
+ progress(*progress_args)
+ gen["progress_args"] = progress_args
+
+ return callback
+def abort_generation(state):
+ gen = get_gen_info(state)
+ if "in_progress" in gen:
+
+ gen["abort"] = True
+ gen["extra_orders"] = 0
+ wan_model._interrupt= True
+ msg = "Processing Request to abort Current Generation"
+ gr.Info(msg)
+ return msg, gr.Button(interactive= False)
+ else:
+ return "", gr.Button(interactive= True)
+
+def is_gen_location(state):
+ gen = get_gen_info(state)
+
+ gen_location = gen.get("location",None)
+ if gen_location == None:
+ return None
+ return state["image2video"] == gen_location
+
+
+def refresh_gallery(state, msg):
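+    # rebuilds the gallery, the "current prompt" info panel and the queue table from session state;
+    # wired to the change event of the hidden gen_status textbox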
+ gen = get_gen_info(state)
+
+ if is_gen_location(state):
+ gen["last_msg"] = msg
+ file_list = gen.get("file_list", None)
+ choice = gen.get("selected",0)
+ in_progress = "in_progress" in gen
+ if in_progress:
+ if gen.get("last_selected", True):
+ choice = max(len(file_list) - 1,0)
+
+ queue = gen.get("queue", [])
+ abort_interactive = not gen.get("abort", False)
+ if not in_progress or len(queue) == 0:
+ return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive)
+ else:
+ task = queue[0]
+ start_img_md = ""
+ end_img_md = ""
+ prompt = task["prompt"]
+
+ if task.get('image2video'):
+ start_img_uri = task.get('start_image_data_base64')
+ end_img_uri = task.get('end_image_data_base64')
+ thumbnail_size = "100px"
+ if start_img_uri:
+                    start_img_md = f'<img src="{start_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
+ if end_img_uri:
+                    end_img_md = f'<img src="{end_img_uri}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size};" />'
+
+ label = f"Prompt of Video being Generated"
+
+ html = "| " + prompt + " | "
+ if start_img_md != "":
+ html += "" + start_img_md + " | "
+ if end_img_md != "":
+ html += "" + end_img_md + " | "
+
+ html += "
"
+ html_output = gr.HTML(html, visible= True)
+ return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive)
+
+
+
+def finalize_generation(state):
+ gen = get_gen_info(state)
+ choice = gen.get("selected",0)
+ if "in_progress" in gen:
+ del gen["in_progress"]
+ if gen.get("last_selected", True):
+ file_list = gen.get("file_list", [])
+ choice = len(file_list) - 1
+
+
+ gen["extra_orders"] = 0
+ time.sleep(0.2)
+ global gen_in_progress
+ gen_in_progress = False
+ return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
+
def refresh_gallery_on_trigger(state):
- if(state.get("update_gallery", False)):
- state['update_gallery'] = False
- return gr.update(value=state.get("file_list", []))
+ gen = get_gen_info(state)
+
+ if(gen.get("update_gallery", False)):
+ gen['update_gallery'] = False
+ return gr.update(value=gen.get("file_list", []))
def select_video(state , event_data: gr.EventData):
data= event_data._data
+ gen = get_gen_info(state)
+
if data!=None:
choice = data.get("index",0)
- file_list = state.get("file_list", [])
- state["last_selected"] = (choice + 1) >= len(file_list)
- state["selected"] = choice
+ file_list = gen.get("file_list", [])
+ gen["last_selected"] = (choice + 1) >= len(file_list)
+ gen["selected"] = choice
return
def expand_slist(slist, num_inference_steps ):
@@ -1221,6 +1420,7 @@ def convert_image(image):
def generate_video(
task_id,
+ progress,
prompt,
negative_prompt,
resolution,
@@ -1254,22 +1454,29 @@ def generate_video(
):
global wan_model, offloadobj, reload_needed, last_model_type
+ gen = get_gen_info(state)
+
+ file_list = gen["file_list"]
+ prompt_no = gen["prompt_no"]
+
file_model_needed = model_needed(image2video)
- with lock:
- queue_not_empty = len(queue) > 0
- if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
+ # queue = gen.get("queue", [])
+ # with lock:
+ # queue_not_empty = len(queue) > 0
+ # if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
+ if file_model_needed != model_filename or reload_needed:
del wan_model
if offloadobj is not None:
offloadobj.release()
del offloadobj
gc.collect()
- print(f"Loading model {get_model_name(file_model_needed)}...")
+ yield f"Loading model {get_model_name(file_model_needed)}..."
wan_model, offloadobj, trans = load_models(image2video)
- print(f"Model loaded")
+ yield f"Model loaded"
reload_needed= False
if wan_model == None:
- raise gr.Error("Unable to generate a Video while a new configuration is being applied.")
+ gr.Info("Unable to generate a Video while a new configuration is being applied.")
if attention_mode == "auto":
attn = get_auto_attention()
elif attention_mode in attention_modes_supported:
@@ -1278,26 +1485,15 @@ def generate_video(
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
return
- raw_resolution = resolution
- width, height = resolution.split("x")
- width, height = int(width), int(height)
+
+
+ if not image2video:
+ width, height = resolution.split("x")
+ width, height = int(width), int(height)
if slg_switch == 0:
slg_layers = None
- if image2video:
- if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
- gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
- return
- resolution = str(width) + "*" + str(height)
- if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
- gr.Info(f"Resolution {resolution} not supported by image 2 video")
- return
-
- if "1.3B" in model_filename and width * height > 848*480:
- gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
- return
-
offload.shared_state["_attention"] = attn
# VAE Tiling
@@ -1321,16 +1517,7 @@ def generate_video(
trans = wan_model.model
- global gen_in_progress
- gen_in_progress = True
temp_filename = None
- if image2video:
- if video_to_continue != None and len(video_to_continue) >0 :
- input_image_or_video_path = video_to_continue
- # pipeline.num_input_frames = max_frames
- # pipeline.max_frames = max_frames
- else:
- input_image_or_video_path = None
loras = state["loras"]
if len(loras) > 0:
@@ -1374,10 +1561,6 @@ def generate_video(
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
seed = None if seed == -1 else seed
# negative_prompt = "" # not applicable in the inference
-
- if "abort" in state:
- del state["abort"]
- state["in_progress"] = True
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
@@ -1414,49 +1597,53 @@ def generate_video(
if seed == None or seed <0:
seed = random.randint(0, 999999999)
- global file_list
- clear_file_list = server_config.get("clear_file_list", 0)
- file_list = state.get("file_list", [])
- if clear_file_list > 0:
- file_list_current_size = len(file_list)
- keep_file_from = max(file_list_current_size - clear_file_list, 0)
- files_removed = keep_file_from
- choice = state.get("selected",0)
- choice = max(choice- files_removed, 0)
- file_list = file_list[ keep_file_from: ]
- else:
- file_list = []
- choice = 0
- state["selected"] = choice
- state["file_list"] = file_list
-
-
global save_path
os.makedirs(save_path, exist_ok=True)
video_no = 0
abort = False
- repeats = f"{video_no}/{repeat_generation}"
- callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
- offload.shared_state["callback"] = callback
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
- for i in range(repeat_generation):
+ gen["abort"] = False
+ gen["prompt"] = prompt
+ repeat_no = 0
+ extra_generation = 0
+ while True:
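+        # re-read extra_orders on every pass so "One More Sample" requests made during
+        # generation extend total_generation on the fly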
+ extra_generation += gen.get("extra_orders",0)
+ gen["extra_orders"] = 0
+ total_generation = repeat_generation + extra_generation
+ gen["total_generation"] = total_generation
+ if abort or repeat_no >= total_generation:
+ break
+ repeat_no +=1
+ gen["repeat_no"] = repeat_no
+ prompts_max = gen["prompts_max"]
+ status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
+
+ yield status
+
+ gen["progress_status"] = status
+ gen["progress_phase"] = (" - Encoding Prompt", -1 )
+ callback = build_callback(state, trans, progress, status, num_inference_steps)
+ progress_args = [0, status + " - Encoding Prompt"]
+ progress(*progress_args )
+ gen["progress_args"] = progress_args
+
try:
- with tracker_lock:
- start_time = time.time()
- progress_tracker[task_id] = {
- 'current_step': 0,
- 'total_steps': num_inference_steps,
- 'start_time': start_time,
- 'last_update': start_time,
- 'repeats': f"{video_no}/{repeat_generation}",
- 'status': "Encoding Prompt"
- }
+ start_time = time.time()
+ # with tracker_lock:
+ # progress_tracker[task_id] = {
+ # 'current_step': 0,
+ # 'total_steps': num_inference_steps,
+ # 'start_time': start_time,
+ # 'last_update': start_time,
+ # 'repeats': repeat_generation, # f"{video_no}/{repeat_generation}",
+ # 'status': "Encoding Prompt"
+ # }
if trans.enable_teacache:
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
- trans.teacache_skipped_steps = 0
+ trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
@@ -1464,8 +1651,8 @@ def generate_video(
if image2video:
samples = wan_model.generate(
prompt,
- convert_image(image_to_continue),
- convert_image(image_to_end) if image_to_end != None else None,
+ image_to_continue,
+ image_to_end if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
@@ -1483,7 +1670,7 @@ def generate_video(
slg_end = slg_end/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
- add_frames_for_end_image = not "Fun" in transformer_filename_i2v
+ add_frames_for_end_image = not "Fun" in transformer_filename_i2v,
)
else:
samples = wan_model.generate(
@@ -1507,7 +1694,6 @@ def generate_video(
cfg_zero_step = cfg_zero_step,
)
except Exception as e:
- gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
@@ -1530,15 +1716,23 @@ def generate_video(
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
+
+ _ , exc_value, exc_traceback = sys.exc_info()
+
state["prompt"] = ""
if VRAM_crash:
- raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
+ new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
else:
- raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
+ new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
+ tb = traceback.format_exc().split('\n')[:-2]
+ print('\n'.join(tb))
+ raise gr.Error(new_error, print_exception= False)
+
finally:
- with tracker_lock:
- if task_id in progress_tracker:
- del progress_tracker[task_id]
+ pass
+ # with tracker_lock:
+ # if task_id in progress_tracker:
+ # del progress_tracker[task_id]
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
@@ -1555,7 +1749,7 @@ def generate_video(
end_time = time.time()
abort = True
state["prompt"] = ""
- print(f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s")
+ # yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
@@ -1574,7 +1768,7 @@ def generate_video(
normalize=True,
value_range=(-1, 1))
- configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+ configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
metadata_choice = server_config.get("metadata_choice","metadata")
@@ -1596,9 +1790,159 @@ def generate_video(
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
- gen_in_progress = False
offload.unload_loras_from_model(trans)
+def prepare_generate_video(state):
+ if state.get("validate_success",0) != 1:
+ return gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False)
+ else:
+ return gr.Button(visible= False), gr.Button(visible= True), gr.Column(visible= True)
+
+
+def wait_tasks_done(state, progress=gr.Progress()):
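+    # polls the shared gen dict so the tab that did not start the generation can mirror its
+    # progress messages and progress bar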
+
+ gen = get_gen_info(state)
+ gen_location = is_gen_location(state)
+
+ last_msg = gen.get("last_msg", "")
+ if len(last_msg) > 0:
+ yield last_msg
+
+ if gen_location == None or gen_location:
+ return gr.Text()
+
+
+ while True:
+
+ msg = gen.get("last_msg", "")
+ if len(msg) > 0 and last_msg != msg:
+ yield msg
+ last_msg = msg
+ progress_args = gen.get("progress_args", None)
+ if progress_args != None:
+ progress(*progress_args)
+
+ in_progress= gen.get("in_progress", False)
+ if not in_progress:
+ break
+ time.sleep(0.5)
+
+
+
+def process_tasks(state, progress=gr.Progress()):
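+    # drains the session queue one task at a time; every yielded status string updates the hidden
+    # gen_status textbox, whose change event refreshes the gallery and queue table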
+ gen = get_gen_info(state)
+ queue = gen.get("queue", [])
+
+ if len(queue) == 0:
+ return
+ gen = get_gen_info(state)
+ gen["location"] = state["image2video"]
+ clear_file_list = server_config.get("clear_file_list", 0)
+ file_list = gen.get("file_list", [])
+ if clear_file_list > 0:
+ file_list_current_size = len(file_list)
+ keep_file_from = max(file_list_current_size - clear_file_list, 0)
+ files_removed = keep_file_from
+ choice = gen.get("selected",0)
+ choice = max(choice- files_removed, 0)
+ file_list = file_list[ keep_file_from: ]
+ else:
+ file_list = []
+ choice = 0
+ gen["selected"] = choice
+ gen["file_list"] = file_list
+
+ start_time = time.time()
+
+ global gen_in_progress
+ gen_in_progress = True
+ gen["in_progress"] = True
+
+ prompt_no = 0
+ while len(queue) > 0:
+ prompt_no += 1
+ gen["prompt_no"] = prompt_no
+ task = queue[0]
+ task_id = task["id"]
+ params = task['params']
+ iterator = iter(generate_video(task_id, progress, **params))
+ while True:
+ try:
+ ok = False
+ status = next(iterator, "#")
+ if status == "#":
+ break
+ ok = True
+ except Exception as e:
+ _ , exc_value, exc_traceback = sys.exc_info()
+ raise exc_value.with_traceback(exc_traceback)
+ finally:
+ if not ok:
+ queue.clear()
+ yield status
+
+ queue[:] = [item for item in queue if item['id'] != task['id']]
+
+ gen["prompts_max"] = 0
+ gen["prompt"] = ""
+ end_time = time.time()
+ if gen.get("abort"):
+ yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
+ else:
+ yield f"Total Generation Time: {end_time-start_time:.1f}s"
+
+
+def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max):
+ if prompts_max == 1:
+ if repeat_max == 1:
+ return "Video"
+ else:
+ return f"Sample {repeat_no}/{repeat_max}"
+ else:
+ if repeat_max == 1:
+ return f"Prompt {prompt_no}/{prompts_max}"
+ else:
+ return f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}"
+
+
+refresh_id = 0
+
+def get_new_refresh_id():
+ global refresh_id
+ refresh_id += 1
+ return refresh_id
+
+def update_status(state):
+ gen = get_gen_info(state)
+ prompt_no = gen["prompt_no"]
+ prompts_max = gen.get("prompts_max",0)
+ total_generation = gen["total_generation"]
+ repeat_no = gen["repeat_no"]
+ status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
+ gen["progress_status"] = status
+ gen["refresh"] = get_new_refresh_id()
+
+
+def one_more_sample(state):
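+    # schedules one extra repetition of the prompt currently being generated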
+ gen = get_gen_info(state)
+ extra_orders = gen.get("extra_orders", 0)
+ extra_orders += 1
+ gen["extra_orders"] = extra_orders
+ in_progress = gen.get("in_progress", False)
+ if not in_progress :
+ return state
+ prompt_no = gen["prompt_no"]
+ prompts_max = gen.get("prompts_max",0)
+ total_generation = gen["total_generation"] + extra_orders
+ repeat_no = gen["repeat_no"]
+ status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
+
+
+ gen["progress_status"] = status
+ gen["refresh"] = get_new_refresh_id()
+ gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")
+
+ return state
def get_new_preset_msg(advanced = True):
if advanced:
@@ -1650,7 +1994,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
lset_name_filename = lset_name + ".lset"
- full_lset_name_filename = os.path.join(get_lora_dir(use_image2video), lset_name_filename)
+ full_lset_name_filename = os.path.join(get_lora_dir(state["image2video"]), lset_name_filename)
with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(lset, indent=4))
@@ -1667,7 +2011,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
def delete_lset(state, lset_name):
loras_presets = state["loras_presets"]
- lset_name_filename = os.path.join( get_lora_dir(use_image2video), sanitize_file_name(lset_name) + ".lset" )
+ lset_name_filename = os.path.join( get_lora_dir(state["image2video"]), sanitize_file_name(lset_name) + ".lset" )
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
if not os.path.isfile(lset_name_filename):
raise gr.Error(f"Preset '{lset_name}' not found ")
@@ -1688,8 +2032,8 @@ def delete_lset(state, lset_name):
def refresh_lora_list(state, lset_name, loras_choices):
loras_names = state["loras_names"]
prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
-
- loras, loras_names, loras_presets, _, _, _, _ = setup_loras(use_image2video, None, get_lora_dir(use_image2video), lora_preselected_preset, None)
+ image2video= state["image2video"]
+ loras, loras_names, loras_presets, _, _, _, _ = setup_loras(image2video, None, get_lora_dir(image2video), lora_preselected_preset, None)
state["loras"] = loras
state["loras_names"] = loras_names
state["loras_presets"] = loras_presets
@@ -1729,7 +2073,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m
gr.Info("Please choose a preset in the list or create one")
else:
loras = state["loras"]
- loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
+ loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["image2video"], lset_name, loras)
if len(error) > 0:
gr.Info(error)
else:
@@ -1930,10 +2274,11 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
if state.get("validate_success",0) != 1:
return
- ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+ image2video = state["image2video"]
+ ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
- defaults_filename = get_settings_file_name(use_image2video)
+ defaults_filename = get_settings_file_name(image2video)
with open(defaults_filename, "w", encoding="utf-8") as f:
json.dump(ui_defaults, f, indent=4)
@@ -1976,7 +2321,12 @@ def generate_video_tab(image2video=False):
state_dict["advanced"] = advanced
state_dict["loras_model"] = filename
- preset_to_load = lora_preselected_preset if lora_preselected_preset_for_i2v == image2video else ""
+ state_dict["image2video"] = image2video
+ gen = dict()
+ gen["queue"] = []
+ state_dict["gen"] = gen
+
+ preset_to_load = lora_preselected_preset if use_image2video == image2video else ""
loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
@@ -1989,7 +2339,7 @@ def generate_video_tab(image2video=False):
launch_loras = []
launch_multis_str = ""
- if len(default_lora_preset) > 0 and image2video == lora_preselected_preset_for_i2v:
+ if len(default_lora_preset) > 0 and image2video == use_image2video:
launch_preset = default_lora_preset
launch_prompt = default_lora_preset_prompt
launch_loras = default_loras_choices
@@ -2014,15 +2364,6 @@ def generate_video_tab(image2video=False):
header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
- with gr.Row(visible= image2video):
- with gr.Row(scale =2):
- gr.Markdown("Wan2GP's Lora Festival ! Press the following button to download i2v Remade Loras collection (and bonuses Loras).")
- with gr.Row(scale =1):
- download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
- with gr.Row(scale =1):
- gr.Markdown("")
- with gr.Row(visible= image2video) as download_status_row:
- download_status = gr.Markdown()
with gr.Row():
with gr.Column():
with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
@@ -2250,89 +2591,112 @@ def generate_video_tab(image2video=False):
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Row():
- save_settings_btn = gr.Button("Set Settings as Default")
+ save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
with gr.Column():
+ gen_status = gr.Text(label="Status", interactive= False)
+ full_sync = gr.Text(label="Status", interactive= False, visible= False)
+ light_sync = gr.Text(label="Status", interactive= False, visible= False)
gen_progress_html = gr.HTML(
label="Status",
value="Idle",
- elem_id="generation_progress_bar_container"
+ elem_id="generation_progress_bar_container", visible= False
)
output = gr.Gallery(
label="Generated videos", show_label=False, elem_id="gallery"
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
- queue_df = gr.DataFrame(
- headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "Start", "End", "", "", ""],
- datatype=["str", "str", "str", "str", "str", "markdown", "markdown", "markdown", "str", "str", "str"],
- interactive=False,
- col_count=(11, "fixed"),
- wrap=True,
- value=update_queue_data,
- every=1,
- elem_id="queue_df"
- )
- def handle_selection(evt: gr.SelectData):
- if evt.index is None:
- return gr.update(), gr.update(), gr.update(visible=False)
- row_index, col_index = evt.index
- cell_value = None
- if col_index in [8, 9, 10]:
- if col_index == 8: cell_value = "↑"
- elif col_index == 9: cell_value = "↓"
- elif col_index == 10: cell_value = "✖"
- if col_index == 8:
- new_df_data = move_up([row_index])
- return new_df_data, gr.update(), gr.update(visible=False)
- elif col_index == 9:
- new_df_data = move_down([row_index])
- return new_df_data, gr.update(), gr.update(visible=False)
- elif col_index == 10:
- new_df_data = remove_task([row_index])
- return new_df_data, gr.update(), gr.update(visible=False)
- start_img_col_idx = 6
- end_img_col_idx = 7
- image_data_to_show = None
- if col_index == start_img_col_idx:
- with lock:
- if row_index < len(queue):
- image_data_to_show = queue[row_index].get('start_image_data')
- elif col_index == end_img_col_idx:
- with lock:
- if row_index < len(queue):
- image_data_to_show = queue[row_index].get('end_image_data')
+ add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
- if image_data_to_show:
- return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
- else:
- return gr.update(), gr.update(), gr.update(visible=False)
- selected_indices = gr.State([])
- queue_df.select(
- fn=handle_selection,
- inputs=None,
- outputs=[queue_df, modal_image_display, modal_container],
- )
- gallery_update_trigger.change(
- fn=refresh_gallery_on_trigger,
- inputs=[state],
- outputs=[output]
- )
- queue_df.change(
- fn=refresh_gallery,
- inputs=[state],
- outputs=[gallery_update_trigger]
- ).then(
- fn=refresh_progress,
- inputs=None,
- outputs=[progress_update_trigger]
- )
- progress_update_trigger.change(
- fn=update_generation_status,
- inputs=[progress_update_trigger],
- outputs=[gen_progress_html],
- show_progress="hidden"
- )
+ with gr.Column(visible= False) as current_gen_column:
+ with gr.Row():
+ gen_info = gr.HTML(visible=False, min_height=1)
+ with gr.Row():
+ onemore_btn = gr.Button("One More Sample Please !")
+ abort_btn = gr.Button("Abort")
+
+ queue_df = gr.DataFrame(
+ headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
+ datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
+ interactive=False,
+ col_count=(9, "fixed"),
+ wrap=True,
+ value=[],
+ visible= False,
+ # every=1,
+ elem_id="queue_df"
+ )
+ # queue_df = gr.HTML("",
+ # visible= False,
+ # elem_id="queue_df"
+ # )
+
+ def handle_selection(state, evt: gr.SelectData):
+ gen = get_gen_info(state)
+ queue = gen.get("queue", [])
+
+ if evt.index is None:
+ return gr.update(), gr.update(), gr.update(visible=False)
+ row_index, col_index = evt.index
+ cell_value = None
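+            # columns 6/7/8 hold the ↑ / ↓ / ✖ action glyphs, columns 4/5 the start/end thumbnails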
+ if col_index in [6, 7, 8]:
+ if col_index == 6: cell_value = "↑"
+ elif col_index == 7: cell_value = "↓"
+ elif col_index == 8: cell_value = "✖"
+ if col_index == 6:
+ new_df_data = move_up(queue, [row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 7:
+ new_df_data = move_down(queue, [row_index])
+ return new_df_data, gr.update(), gr.update(visible=False)
+ elif col_index == 8:
+ new_df_data = remove_task(queue, [row_index])
+ gen["prompts_max"] = gen.get("prompts_max",0) - 1
+ update_status(state)
+ return new_df_data, gr.update(), gr.update(visible=False)
+ start_img_col_idx = 4
+ end_img_col_idx = 5
+ image_data_to_show = None
+ if col_index == start_img_col_idx:
+ with lock:
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('start_image_data')
+ elif col_index == end_img_col_idx:
+ with lock:
+ if row_index < len(queue):
+ image_data_to_show = queue[row_index].get('end_image_data')
+
+ if image_data_to_show:
+ return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
+ else:
+ return gr.update(), gr.update(), gr.update(visible=False)
+ selected_indices = gr.State([])
+ queue_df.select(
+ fn=handle_selection,
+ inputs=state,
+ outputs=[queue_df, modal_image_display, modal_container],
+ )
+ # gallery_update_trigger.change(
+ # fn=refresh_gallery_on_trigger,
+ # inputs=[state],
+ # outputs=[output]
+ # )
+ # queue_df.change(
+ # fn=refresh_gallery,
+ # inputs=[state],
+ # outputs=[gallery_update_trigger]
+ # ).then(
+ # fn=refresh_progress,
+ # inputs=None,
+ # outputs=[progress_update_trigger]
+ # )
+ progress_update_trigger.change(
+ fn=update_generation_status,
+ inputs=[progress_update_trigger],
+ outputs=[gen_progress_html],
+ show_progress="hidden"
+ )
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
@@ -2348,53 +2712,114 @@ def generate_video_tab(image2video=False):
)
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
- download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
output.select(select_video, state, None )
- generate_btn.click(
- fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
+ gen_status.change(refresh_gallery,
+ inputs = [state, gen_status],
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
+
+ full_sync.change(refresh_gallery,
+ inputs = [state, gen_status],
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
+ ).then( fn=wait_tasks_done,
+ inputs= [state],
+ outputs =[gen_status],
+ ).then(finalize_generation,
+ inputs= [state],
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
+ )
+ light_sync.change(refresh_gallery,
+ inputs = [state, gen_status],
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
+ )
+
+ abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
+ onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
+
+
+ gen_inputs=[
+ prompt,
+ negative_prompt,
+ resolution,
+ video_length,
+ seed,
+ num_inference_steps,
+ guidance_scale,
+ flow_shift,
+ embedded_guidance_scale,
+ repeat_generation,
+ multi_images_gen_type,
+ tea_cache_setting,
+ tea_cache_start_step_perc,
+ loras_choices,
+ loras_mult_choices,
+ image_prompt_type_radio,
+ image_to_continue,
+ image_to_end,
+ video_to_continue,
+ max_frames,
+ RIFLEx_setting,
+ slg_switch,
+ slg_layers,
+ slg_start_perc,
+ slg_end_perc,
+ cfg_star_switch,
+ cfg_zero_step,
+ state,
+ gr.State(image2video)
+ ]
+
+ generate_btn.click(fn=validate_wizard_prompt,
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=process_prompt_and_add_tasks,
+ inputs = gen_inputs,
+ outputs= queue_df
+ ).then(fn=prepare_generate_video,
+ inputs= [state],
+ outputs= [generate_btn, add_to_queue_btn, current_gen_column],
+ ).then(fn=process_tasks,
+ inputs= [state],
+ outputs= [gen_status],
+ ).then(finalize_generation,
+ inputs= [state],
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
+ )
+
+ add_to_queue_btn.click(fn=validate_wizard_prompt,
+ inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
).then(
fn=process_prompt_and_add_tasks,
- inputs=[
- prompt,
- negative_prompt,
- resolution,
- video_length,
- seed,
- num_inference_steps,
- guidance_scale,
- flow_shift,
- embedded_guidance_scale,
- repeat_generation,
- multi_images_gen_type,
- tea_cache_setting,
- tea_cache_start_step_perc,
- loras_choices,
- loras_mult_choices,
- image_prompt_type_radio,
- image_to_continue,
- image_to_end,
- video_to_continue,
- max_frames,
- RIFLEx_setting,
- slg_switch,
- slg_layers,
- slg_start_perc,
- slg_end_perc,
- cfg_star_switch,
- cfg_zero_step,
- state,
- gr.State(image2video)
- ],
+ inputs = gen_inputs,
outputs=queue_df
+ ).then(
+ fn=update_status,
+ inputs = [state],
)
+
+
close_modal_button.click(
lambda: gr.update(visible=False),
inputs=[],
outputs=[modal_container]
)
- return loras_column, loras_choices, presets_column, lset_name, header, state
+ return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
+def generate_download_tab(presets_column, loras_column, lset_name, loras_choices, state):
+ with gr.Row():
+ with gr.Row(scale =2):
+ gr.Markdown("Wan2GP's Lora Festival ! Press the following button to download i2v Remade Loras collection (and bonuses Loras).")
+ with gr.Row(scale =1):
+ download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
+ with gr.Row(scale =1):
+ gr.Markdown("")
+ with gr.Row() as download_status_row:
+ download_status = gr.Markdown()
+
+ download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+
+
def generate_configuration_tab():
state_dict = {}
state = gr.State(state_dict)
@@ -2411,7 +2836,7 @@ def generate_configuration_tab():
value= index,
label="Transformer model for Text to Video",
interactive= not lock_ui_transformer,
- visible=True #not use_image2video
+ visible=True
)
index = transformer_choices_i2v.index(transformer_filename_i2v)
index = 0 if index ==0 else index
@@ -2428,7 +2853,7 @@ def generate_configuration_tab():
value= index,
label="Transformer model for Image to Video",
interactive= not lock_ui_transformer,
- visible = True # use_image2video,
+ visible = True,
)
index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
@@ -2524,7 +2949,7 @@ def generate_configuration_tab():
reload_choice = gr.Dropdown(
choices=[
("When changing tabs", 1),
- ("When pressing generate", 2),
+ ("When pressing Generate", 2),
],
value=server_config.get("reload_model",2),
label="Reload model"
@@ -2577,19 +3002,46 @@ def generate_about_tab():
gr.Markdown("- Remade_AI : for creating their awesome Loras collection")
-def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
- global lora_model_filename, use_image2video
-
+def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
new_t2v = evt.index == 0
new_i2v = evt.index == 1
- use_image2video = new_i2v
+ i2v_light_sync = gr.Text()
+ t2v_light_sync = gr.Text()
+ i2v_full_sync = gr.Text()
+ t2v_full_sync = gr.Text()
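+ # the two tabs share a single generation state ("gen"): copy it over from whichever tab was active last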
+ if new_t2v or new_i2v:
+ last_tab_was_image2video = global_state.get("last_tab_was_image2video", None)
+ if last_tab_was_image2video == None or last_tab_was_image2video:
+ gen = i2v_state["gen"]
+ t2v_state["gen"] = gen
+ else:
+ gen = t2v_state["gen"]
+ i2v_state["gen"] = gen
+
+
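+ # request a full sync of the newly selected tab when a generation is in progress under the other tab's model; otherwise a light sync is enough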
+ if last_tab_was_image2video != None and new_t2v != new_i2v:
+ gen_location = gen.get("location", None)
+ if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
+ if new_i2v:
+ i2v_full_sync = gr.Text(str(time.time()))
+ else:
+ t2v_full_sync = gr.Text(str(time.time()))
+ else:
+ if new_i2v:
+ i2v_light_sync = gr.Text(str(time.time()))
+ else:
+ t2v_light_sync = gr.Text(str(time.time()))
+
+
+ global_state["last_tab_was_image2video"] = new_i2v
if(server_config.get("reload_model",2) == 1):
- with lock:
- queue_empty = len(queue) == 0
+ queue = gen.get("queue", [])
+
+ queue_empty = len(queue) == 0
if queue_empty:
global wan_model, offloadobj
if wan_model is not None:
@@ -2599,7 +3051,7 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
wan_model = None
gc.collect()
torch.cuda.empty_cache()
- wan_model, offloadobj, trans = load_models(use_image2video)
+ wan_model, offloadobj, trans = load_models(new_i2v)
del trans
if new_t2v or new_i2v:
@@ -2625,11 +3077,15 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
gr.Column(visible= visible),
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
t2v_header,
+ t2v_light_sync,
+ t2v_full_sync,
gr.Column(),
gr.Dropdown(),
gr.Column(),
gr.Dropdown(),
- i2v_header,
+ gr.Markdown(),
+ gr.Text(),
+ gr.Text(),
]
else:
return [
@@ -2637,16 +3093,21 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
gr.Dropdown(),
gr.Column(),
gr.Dropdown(),
- t2v_header,
+ gr.Markdown(),
+ gr.Text(),
+ gr.Text(),
+ gr.Text(),
gr.Column(visible= visible),
gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
gr.Column(visible= visible),
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
i2v_header,
+ i2v_light_sync,
+ i2v_full_sync,
]
- return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header,
- gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header]
+ return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, t2v_light_sync, t2v_full_sync,
+ gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header, i2v_light_sync, i2v_full_sync]
def create_demo():
@@ -2706,112 +3167,112 @@ def create_demo():
overflow: hidden;
text-overflow: ellipsis;
}
- #queue_df td:nth-child(-n+5) {
- cursor: default !important;
- pointer-events: none;
- }
- #queue_df td:nth-child(6) {
- cursor: default !important;
- }
- #queue_df th {
- pointer-events: none;
- text-align: center;
- vertical-align: middle;
- }
- #queue_df table {
- width: 100%;
- overflow: hidden !important;
- }
- #queue_df::-webkit-scrollbar {
- display: none !important;
- }
- #queue_df {
- scrollbar-width: none !important;
- -ms-overflow-style: none !important;
- }
- #queue_df th:nth-child(1),
- #queue_df td:nth-child(1) {
- width: 90px;
- text-align: center;
- vertical-align: middle;
- }
- #queue_df th:nth-child(1) {
- font-size: 0.8em;
- }
- #queue_df th:nth-child(2),
- #queue_df td:nth-child(2) {
- width: 85px;
- text-align: center;
- vertical-align: middle;
- }
- #queue_df th:nth-child(2) {
- font-size: 0.5em;
- }
- #queue_df th:nth-child(3),
- #queue_df td:nth-child(3) {
- width: 75px;
- text-align: center;
- vertical-align: middle;
- }
- #queue_df th:nth-child(3) {
- font-size: 0.6em;
- }
- #queue_df th:nth-child(4),
- #queue_df td:nth-child(4) {
- width: 65px;
- text-align: center;
- white-space: nowrap;
- }
- #queue_df th:nth-child(4) {
- font-size: 0.9em;
- }
- #queue_df th:nth-child(5),
- #queue_df td:nth-child(5) {
- width: 60px;
- text-align: center;
- white-space: nowrap;
- }
- #queue_df th:nth-child(6),
- #queue_df td:nth-child(6) {
- width: auto;
- text-align: center;
- white-space: normal;
- }
- #queue_df th:nth-child(6) {
- font-size: 0.8em;
- }
- #queue_df th:nth-child(7), #queue_df td:nth-child(7),
- #queue_df th:nth-child(8), #queue_df td:nth-child(8) {
- width: 60px;
- text-align: center;
- vertical-align: middle;
- }
- #queue_df td:nth-child(7) img,
- #queue_df td:nth-child(8) img {
- max-width: 50px;
- max-height: 50px;
- object-fit: contain;
- display: block;
- margin: auto;
- cursor: pointer;
- }
- #queue_df th:nth-child(9), #queue_df td:nth-child(9),
- #queue_df th:nth-child(10), #queue_df td:nth-child(10),
- #queue_df th:nth-child(11), #queue_df td:nth-child(11) {
- width: 20px;
- padding: 2px !important;
- cursor: pointer;
- text-align: center;
- font-weight: bold;
- vertical-align: middle;
- }
- #queue_df td:nth-child(7):hover,
- #queue_df td:nth-child(8):hover,
- #queue_df td:nth-child(9):hover,
- #queue_df td:nth-child(10):hover,
- #queue_df td:nth-child(11):hover {
- background-color: #e0e0e0;
- }
+ /*
+ #queue_df td:nth-child(-n+5) {
+ cursor: default !important;
+ pointer-events: none;
+ }
+ #queue_df td:nth-child(6) {
+ cursor: default !important;
+ }
+ #queue_df th {
+ pointer-events: none;
+ text-align: center;
+ vertical-align: middle;
+ }
+ #queue_df table {
+ width: 100%;
+ overflow: hidden !important;
+ }
+ #queue_df::-webkit-scrollbar {
+ display: none !important;
+ }
+ #queue_df {
+ scrollbar-width: none !important;
+ -ms-overflow-style: none !important;
+ }
+ #queue_df th:nth-child(1),
+ #queue_df td:nth-child(1) {
+ width: 90px;
+ text-align: center;
+ vertical-align: middle;
+ }
+ #queue_df th:nth-child(1) {
+ font-size: 0.8em;
+ }
+ #queue_df th:nth-child(2),
+ #queue_df td:nth-child(2) {
+ width: 85px;
+ text-align: center;
+ vertical-align: middle;
+ }
+ #queue_df th:nth-child(2) {
+ font-size: 0.5em;
+ }
+ #queue_df th:nth-child(3),
+ #queue_df td:nth-child(3) {
+ width: 75px;
+ text-align: center;
+ vertical-align: middle;
+ }
+ #queue_df th:nth-child(3) {
+ font-size: 0.6em;
+ }
+ #queue_df th:nth-child(4),
+ #queue_df td:nth-child(4) {
+ width: 65px;
+ text-align: center;
+ white-space: nowrap;
+ }
+ #queue_df th:nth-child(4) {
+ font-size: 0.9em;
+ }
+ #queue_df th:nth-child(5),
+ #queue_df td:nth-child(5) {
+ width: 60px;
+ text-align: center;
+ white-space: nowrap;
+ }
+ #queue_df th:nth-child(6),
+ #queue_df td:nth-child(6) {
+ width: auto;
+ text-align: center;
+ white-space: normal;
+ }
+ #queue_df th:nth-child(6) {
+ font-size: 0.8em;
+ }
+ #queue_df th:nth-child(7), #queue_df td:nth-child(7),
+ #queue_df th:nth-child(8), #queue_df td:nth-child(8) {
+ width: 60px;
+ text-align: center;
+ vertical-align: middle;
+ }
+ #queue_df td:nth-child(7) img,
+ #queue_df td:nth-child(8) img {
+ max-width: 50px;
+ max-height: 50px;
+ object-fit: contain;
+ display: block;
+ margin: auto;
+ cursor: pointer;
+ }
+ #queue_df th:nth-child(9), #queue_df td:nth-child(9),
+ #queue_df th:nth-child(10), #queue_df td:nth-child(10),
+ #queue_df th:nth-child(11), #queue_df td:nth-child(11) {
+ width: 20px;
+ padding: 2px !important;
+ cursor: pointer;
+ text-align: center;
+ font-weight: bold;
+ vertical-align: middle;
+ }
+ #queue_df td:nth-child(7):hover,
+ #queue_df td:nth-child(8):hover,
+ #queue_df td:nth-child(9):hover,
+ #queue_df td:nth-child(10):hover,
+ #queue_df td:nth-child(11):hover {
+ background-color: #e0e0e0;
+ }
+ */
#image-modal-container {
position: fixed;
top: 0;
@@ -2893,8 +3354,8 @@ def create_demo():
pointer-events: none;
}
"""
- with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
- gr.Markdown("Wan 2.1GP v3.3 by DeepBeepMeep (Updates)
")
+ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
+ gr.Markdown("Wan 2.1GP v3.4 by DeepBeepMeep (Updates)
")
gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@@ -2904,30 +3365,34 @@ def create_demo():
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
-
+ global_dict = {}
+ global_dict["last_tab_was_image2video"] = use_image2video
+ global_state = gr.State(global_dict)
with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
with gr.Tab("Text To Video", id="t2v") as t2v_tab:
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_state = generate_video_tab()
+ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False)
with gr.Tab("Image To Video", id="i2v") as i2v_tab:
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True)
+ i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
if not args.lock_config:
+ with gr.Tab("Downloads", id="downloads") as downloads_tab:
+ generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
with gr.Tab("Configuration"):
generate_configuration_tab()
with gr.Tab("About"):
generate_about_tab()
main_tabs.select(
fn=on_tab_select,
- inputs=[t2v_state, i2v_state],
+ inputs=[global_state, t2v_state, i2v_state],
outputs=[
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header,
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header
+ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync,
+ i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync
]
)
return demo
if __name__ == "__main__":
- threading.Thread(target=runner, daemon=True).start()
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)
if os.name == "nt":
diff --git a/wan/image2video.py b/wan/image2video.py
index 2ea4310..a71e9d4 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -40,7 +40,12 @@ def optimized_scale(positive_flat, negative_flat):
st_star = dot_product / squared_norm
return st_star
-
+
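+# Convert a CHW float tensor in [0, 1] to PIL, resize with Lanczos filtering, and convert back to a CHW float tensor.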
+def resize_lanczos(img, h, w):
+ img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
+ img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
+ return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+
class WanI2V:
@@ -90,7 +95,6 @@ class WanI2V:
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
-
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
@@ -208,16 +212,16 @@ class WanI2V:
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
- img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+ img = TF.to_tensor(img)
lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
any_end_frame = img2 !=None
if any_end_frame:
any_end_frame = True
- img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
+ img2 = TF.to_tensor(img2)
if add_frames_for_end_image:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
-
+
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
@@ -229,6 +233,15 @@ class WanI2V:
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
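+ # prepare two resized copies of the conditioning image: one at the target (h, w) for the VAE conditioning path and one at the CLIP input resolution for the image encoder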
+ clip_image_size = self.clip.model.image_size
+ img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+ img = resize_lanczos(img, clip_image_size, clip_image_size)
+ img = img.sub_(0.5).div_(0.5).to(self.device)
+ if img2!= None:
+ img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+ img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
+ img2 = img2.sub_(0.5).div_(0.5).to(self.device)
+
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
@@ -273,21 +286,32 @@ class WanI2V:
from mmgp import offload
+
offload.last_offload_obj.unload_all()
if any_end_frame:
- img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
- img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
mean2 = 0
enc= torch.concat([
img_interpolated,
- torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16),
- img2_interpolated,
+ torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
+ img_interpolated2,
], dim=1).to(self.device)
else:
enc= torch.concat([
- torch.nn.functional.interpolate(
- img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
- torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
+ img_interpolated,
+ torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
], dim=1).to(self.device)
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
@@ -333,7 +357,8 @@ class WanI2V:
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
- 'pipeline' : self
+ 'pipeline' : self,
+ 'callback' : callback
}
arg_null = {
@@ -342,7 +367,8 @@ class WanI2V:
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
- 'pipeline' : self
+ 'pipeline' : self,
+ 'callback' : callback
}
arg_both= {
@@ -352,7 +378,8 @@ class WanI2V:
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
- 'pipeline' : self
+ 'pipeline' : self,
+ 'callback' : callback
}
if offload_model:
@@ -363,7 +390,7 @@ class WanI2V:
# self.model.to(self.device)
if callback != None:
- callback(-1, None)
+ callback(-1, True)
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@@ -437,7 +464,7 @@ class WanI2V:
del timestep
if callback is not None:
- callback(i, latent)
+ callback(i, False)
x0 = [latent.to(self.device, dtype=torch.bfloat16)]
@@ -451,7 +478,7 @@ class WanI2V:
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
if any_end_frame and add_frames_for_end_image:
- # video[:, -1:] = img2_interpolated
+ # video[:, -1:] = img_interpolated2
video = video[:, :-1]
else:
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 3e2ea5c..2daa00c 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
is_uncond=False,
max_steps = 0,
slg_layers=None,
+ callback = None,
):
r"""
Forward pass through the diffusion model
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
freqs=freqs,
# context=context,
context_lens=context_lens)
-
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
- if "refresh" in offload.shared_state:
- del offload.shared_state["refresh"]
- offload.shared_state["callback"](-1, -1, True)
+ if callback != None:
+ callback(-1, False, True)
if pipeline._interrupt:
if joint_pass:
return None, None
diff --git a/wan/text2video.py b/wan/text2video.py
index 088a9c7..cdcbd4f 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -268,7 +268,7 @@ class WanT2V:
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
- callback(-1, None)
+ callback(-1, True)
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
slg_layers_local = None
@@ -322,7 +322,7 @@ class WanT2V:
del temp_x0
if callback is not None:
- callback(i, latents)
+ callback(i, False)
x0 = latents
if offload_model: