From f7c5193eaad99ec85467c49e21fa2dbf4511bcfa Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 24 Mar 2025 17:20:53 +1100 Subject: [PATCH] add queues --- gradio_server.py | 836 +++++++++++++++++++++++++++-------------------- 1 file changed, 476 insertions(+), 360 deletions(-) diff --git a/gradio_server.py b/gradio_server.py index 65b14ce..7848b39 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -1,5 +1,7 @@ import os import time +import threading +from typing import Dict, Any import argparse from mmgp import offload, safetensors2, profile_type try: @@ -21,6 +23,7 @@ import traceback import math import asyncio from wan.utils import prompt_parser +from typing import List PROMPT_VARS_MAX = 10 target_mmgp_version = "3.3.0" @@ -29,6 +32,194 @@ mmgp_version = version("mmgp") if mmgp_version != target_mmgp_version: print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'") exit() +queue = [] +lock = threading.Lock() +current_task_id = None +task_id = 0 +progress_tracker = {} +tracker_lock = threading.Lock() +file_list = [] + +def runner(): + global current_task_id + while True: + with lock: + for item in queue: + task_id = item['id'] + with tracker_lock: + progress = progress_tracker.get(task_id, {}) + + if item['status'] == "Processing": + current_step = progress.get('current_step', 0) + total_steps = progress.get('total_steps', 0) + elapsed = time.time() - progress.get('start_time', time.time()) + item.update({ + 'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%", + 'steps': f"{current_step}/{total_steps}", + 'time': f"{elapsed:.1f}s" + }) + if not any(item['status'] == "Processing" for item in queue): + for item in queue: + if item['status'] == "Queued": + item['status'] = "Processing" + current_task_id = item['id'] + threading.Thread(target=process_task, args=(item,)).start() + break + time.sleep(1) + +def process_prompt_and_add_tasks( + prompt, + negative_prompt, + resolution, + video_length, + seed, + num_inference_steps, + guidance_scale, + flow_shift, + embedded_guidance_scale, + repeat_generation, + multi_images_gen_type, + tea_cache, + tea_cache_start_step_perc, + loras_choices, + loras_mult_choices, + image_to_continue, + image_to_end, + video_to_continue, + max_frames, + RIFLEx_setting, + slg_switch, + slg_layers, + slg_start, + slg_end, + state_arg, + image2video +): + if len(prompt) ==0: + return + prompt, errors = prompt_parser.process_template(prompt) + if len(errors) > 0: + gr.Info("Error processing prompt template: " + errors) + return + prompts = prompt.replace("\r", "").split("\n") + prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] + if len(prompts) ==0: + return + + for single_prompt in prompts: + task_params = ( + single_prompt, + negative_prompt, + resolution, + video_length, + seed, + num_inference_steps, + guidance_scale, + flow_shift, + embedded_guidance_scale, + repeat_generation, + multi_images_gen_type, + tea_cache, + tea_cache_start_step_perc, + loras_choices, + loras_mult_choices, + image_to_continue, + image_to_end, + video_to_continue, + max_frames, + RIFLEx_setting, + slg_switch, + slg_layers, + slg_start, + slg_end, + state_arg, + image2video + ) + add_video_task(*task_params) + return update_queue_data() + +def process_task(task): + try: + task_id, *params = task['params'] + generate_video(task_id, *params) + finally: + with lock: + queue[:] = [item for item in queue if item['id'] != task['id']] + with tracker_lock: + if task['id'] in progress_tracker: + del progress_tracker[task['id']] + +def add_video_task(*params): + global task_id + with lock: + task_id += 1 + current_task_id = task_id + queue.append({ + "id": current_task_id, + "params": (current_task_id,) + params, + "status": "Queued", + "progress": "0.0%", + "steps": f"0/{params[5]}", + "time": "--", + "prompt": params[0] + }) + return + +def move_up(selected_indices): + if not selected_indices or len(selected_indices) == 0: + return update_queue_data() + idx = selected_indices[0] + if isinstance(idx, list): + idx = idx[0] + idx = int(idx) + with lock: + if idx > 0: + queue[idx], queue[idx-1] = queue[idx-1], queue[idx] + return update_queue_data() + +def move_down(selected_indices): + if not selected_indices or len(selected_indices) == 0: + return update_queue_data() + idx = selected_indices[0] + if isinstance(idx, list): + idx = idx[0] + idx = int(idx) + with lock: + if idx < len(queue)-1: + queue[idx], queue[idx+1] = queue[idx+1], queue[idx] + return update_queue_data() + +def remove_task(selected_indices): + if not selected_indices or len(selected_indices) == 0: + return update_queue_data() + idx = selected_indices[0] + if isinstance(idx, list): + idx = idx[0] + idx = int(idx) + with lock: + if idx < len(queue): + if idx == 0: + wan_model._interrupt = True + del queue[idx] + return update_queue_data() + +def update_queue_data(): + with lock: + data = [] + for item in queue: + data.append([ + str(item['id']), + item['status'], + item.get('progress', "0.0%"), + item.get('steps', ''), + item.get('time', '--'), + (item['prompt'][:47] + '...') if len(item['prompt']) > 50 else item['prompt'], + "↑", + "↓", + "✖" + ]) + return data + def _parse_args(): parser = argparse.ArgumentParser( description="Generate a video from a text prompt or image using Gradio") @@ -788,53 +979,22 @@ def save_video(final_frames, output_path, fps=24): final_frames = (final_frames * 255).astype(np.uint8) ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None) -def build_callback(state, pipe, progress, status, num_inference_steps): - def callback(step_idx, latents, read_state = False): - status = state["progress_status"] - if read_state: - phase, step_idx = state["progress_phase"] - else: - step_idx += 1 - if state.get("abort", False): - # pipe._interrupt = True - phase = " - Aborting" - elif step_idx == num_inference_steps: - phase = " - VAE Decoding" - else: - phase = " - Denoising" - state["progress_phase"] = (phase, step_idx) - status_msg = status + phase - if step_idx >= 0: - progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps) - else: - progress(0, status_msg) - - return callback +def build_callback(task_id, total_steps): + start_time = time.time() + def update_progress(step, _): + with tracker_lock: + elapsed = time.time() - start_time + progress_tracker[task_id] = { + 'current_step': step + 1, + 'total_steps': total_steps, + 'start_time': start_time, + 'last_update': time.time() + } + return update_progress -def abort_generation(state): - if "in_progress" in state: - state["abort"] = True - state["extra_orders"] = 0 - wan_model._interrupt= True - return gr.Button(interactive= False) - else: - return gr.Button(interactive= True) - -def refresh_gallery(state, txt): - file_list = state.get("file_list", None) - prompt = state.get("prompt", "") - if len(prompt) == 0: - return file_list, gr.Text(visible= False, value="") - else: - prompts_max = state.get("prompts_max",0) - prompt_no = state.get("prompt_no",0) - if prompts_max >1 : - label = f"Current Prompt ({prompt_no+1}/{prompts_max})" - else: - label = f"Current Prompt" - return file_list, gr.Text(visible= True, value=prompt, label=label) - - +def refresh_gallery(state): + file_list = state.get("file_list", None) + return file_list def finalize_gallery(state): choice = 0 if "in_progress" in state: @@ -845,7 +1005,7 @@ def finalize_gallery(state): time.sleep(0.2) global gen_in_progress gen_in_progress = False - return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="") + return gr.Gallery(selected_index=choice), gr.Button(interactive=True), gr.Button(visible=False), gr.Checkbox(visible=False), gr.Text(visible=False, value="") def select_video(state , event_data: gr.EventData): data= event_data._data @@ -862,36 +1022,8 @@ def expand_slist(slist, num_inference_steps ): pos += inc return new_slist - -def one_more_video(state): - extra_orders = state.get("extra_orders", 0) - extra_orders += 1 - state["extra_orders"] = extra_orders - prompts_max = state.get("prompts_max",0) - if prompts_max == 0: - return state - prompt_no = state["prompt_no"] - video_no = state["video_no"] - total_video = state["total_video"] - # total_video += (prompts_max- prompt_no) - total_video += 1 - total_generation = state["total_generation"] + extra_orders - state["total_video"] = total_video - - state["progress_status"] = f"Video {video_no}/{total_video}" - offload.shared_state["refresh"] = 1 - # if (prompts_max - prompt_no) > 1: - # gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts") - # else: - gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt") - - return state - -def prepare_generate_video(): - - return gr.Button(visible= False), gr.Checkbox(visible= True) - def generate_video( + task_id, prompt, negative_prompt, resolution, @@ -921,7 +1053,6 @@ def generate_video( progress=gr.Progress() #track_tqdm= True ): - global wan_model, offloadobj reload_needed = state.get("_reload_needed", False) file_model_needed = model_needed(image2video) @@ -930,9 +1061,9 @@ def generate_video( offloadobj.release() offloadobj = None wan_model = None - yield f"Loading model {get_model_name(file_model_needed)}..." + print(f"Loading model {get_model_name(file_model_needed)}...") wan_model, offloadobj, trans = load_models(image2video) - yield f"Model loaded" + print(f"Model loaded") state["_reload_needed"] = False from PIL import Image @@ -949,8 +1080,8 @@ def generate_video( gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.") return - if state.get("validate_success",0) != 1: - return + #if state.get("validate_success",0) != 1: + # return width, height = resolution.split("x") width, height = int(width), int(height) @@ -997,71 +1128,11 @@ def generate_video( global gen_in_progress gen_in_progress = True temp_filename = None - if len(prompt) ==0: - return - prompt, errors = prompt_parser.process_template(prompt) - if len(errors) > 0: - gr.Info(f"Error processing prompt template: " + errors) - prompts = prompt.replace("\r", "").split("\n") - prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] - if len(prompts) ==0: - return if image2video: - if image_to_continue is not None: - if isinstance(image_to_continue, list): - image_to_continue = [ tup[0] for tup in image_to_continue ] - else: - image_to_continue = [image_to_continue] - if image_to_end != None: - if isinstance(image_to_end , list): - image_to_end = [ tup[0] for tup in image_to_end ] - else: - image_to_end = [image_to_end ] - if len(image_to_continue) != len(image_to_end): - gr.Info("The number of start and end images should be the same ") - return - if multi_images_gen_type == 0: - new_prompts = [] - new_image_to_continue = [] - new_image_to_end = [] - for i in range(len(prompts) * len(image_to_continue) ): - new_prompts.append( prompts[ i % len(prompts)] ) - new_image_to_continue.append(image_to_continue[i // len(prompts)] ) - if image_to_end != None: - new_image_to_end.append(image_to_end[i // len(prompts)] ) - prompts = new_prompts - image_to_continue = new_image_to_continue - if image_to_end != None: - image_to_end = new_image_to_end - else: - if len(prompts) >= len(image_to_continue): - if len(prompts) % len(image_to_continue) !=0: - raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_to_continue) - new_image_to_continue = [] - new_image_to_end = [] - for i, _ in enumerate(prompts): - new_image_to_continue.append(image_to_continue[i//rep] ) - if image_to_end != None: - new_image_to_end.append(image_to_end[i//rep] ) - image_to_continue = new_image_to_continue - if image_to_end != None: - image_to_end = new_image_to_end - else: - if len(image_to_continue) % len(prompts) !=0: - raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_to_continue) // len(prompts) - new_prompts = [] - for i, _ in enumerate(image_to_continue): - new_prompts.append( prompts[ i//rep] ) - prompts = new_prompts - - elif video_to_continue != None and len(video_to_continue) >0 : + if video_to_continue != None and len(video_to_continue) >0 : input_image_or_video_path = video_to_continue # pipeline.num_input_frames = max_frames # pipeline.max_frames = max_frames - else: - return else: input_image_or_video_path = None @@ -1148,195 +1219,167 @@ def generate_video( if seed == None or seed <0: seed = random.randint(0, 999999999) - file_list = [] - state["file_list"] = file_list + global file_list + state["file_list"] = file_list global save_path os.makedirs(save_path, exist_ok=True) - video_no = 0 - total_video = repeat_generation * len(prompts) - state["total_video"] = total_video - extra_generation = 0 abort = False - start_time = time.time() - state["prompts_max"] = len(prompts) - for no, prompt in enumerate(prompts): - state["prompt"] = prompt - repeat_no = 0 - state["prompt_no"] = no - extra_generation = 0 - yield f"Prompt No{no}" - while True: - extra_orders = state.get("extra_orders",0) - state["extra_orders"] = 0 - extra_generation += extra_orders - state["total_generation"] = repeat_generation + extra_generation - # total_video += (len(prompts)- no) * extra_orders - total_video += extra_orders - if abort or repeat_no >= (repeat_generation + extra_generation): - break + with tracker_lock: + progress_tracker[task_id] = { + 'current_step': 0, + 'total_steps': num_inference_steps, + 'start_time': time.time(), + 'last_update': time.time() + } + if trans.enable_teacache: + trans.teacache_counter = 0 + trans.num_steps = num_inference_steps + trans.teacache_skipped_steps = 0 + trans.previous_residual_uncond = None + trans.previous_residual_cond = None + callback = build_callback(task_id, num_inference_steps) + offload.shared_state["callback"] = callback + gc.collect() + torch.cuda.empty_cache() + wan_model._interrupt = False + state["progress_status"] = "Starting" + try: + if image2video: + samples = wan_model.generate( + prompt, + image_to_continue.convert('RGB'), + image_to_end.convert('RGB') if image_to_end != None else None, + frame_num=(video_length // 4)* 4 + 1, + max_area=MAX_AREA_CONFIGS[resolution], + shift=flow_shift, + sampling_steps=num_inference_steps, + guide_scale=guidance_scale, + n_prompt=negative_prompt, + seed=seed, + offload_model=False, + callback=callback, + enable_RIFLEx = enable_RIFLEx, + VAE_tile_size = VAE_tile_size, + joint_pass = joint_pass, + slg_layers = slg_layers, + slg_start = slg_start/100, + slg_end = slg_end/100, + ) - if trans.enable_teacache: - trans.teacache_counter = 0 - trans.num_steps = num_inference_steps - trans.teacache_skipped_steps = 0 - trans.previous_residual_uncond = None - trans.previous_residual_cond = None + else: + samples = wan_model.generate( + prompt, + frame_num=(video_length // 4)* 4 + 1, + size=(width, height), + shift=flow_shift, + sampling_steps=num_inference_steps, + guide_scale=guidance_scale, + n_prompt=negative_prompt, + seed=seed, + offload_model=False, + callback=callback, + enable_RIFLEx = enable_RIFLEx, + VAE_tile_size = VAE_tile_size, + joint_pass = joint_pass, + slg_layers = slg_layers, + slg_start = slg_start/100, + slg_end = slg_end/100, + ) + except Exception as e: + gen_in_progress = False + if temp_filename!= None and os.path.isfile(temp_filename): + os.remove(temp_filename) + offload.last_offload_obj.unload_all() + offload.unload_loras_from_model(trans) + # if compile: + # cache_size = torch._dynamo.config.cache_size_limit + # torch.compiler.reset() + # torch._dynamo.config.cache_size_limit = cache_size - video_no += 1 - status = f"Video {video_no}/{total_video}" - state["video_no"] = video_no - state["progress_status"] = status - state["progress_phase"] = (" - Encoding Prompt", -1 ) - progress(0, desc=status + " - Encoding Prompt" ) - callback = build_callback(state, trans, progress, status, num_inference_steps) - offload.shared_state["callback"] = callback - - - gc.collect() - torch.cuda.empty_cache() - wan_model._interrupt = False - try: - if image2video: - samples = wan_model.generate( - prompt, - image_to_continue[no].convert('RGB'), - image_to_end[no].convert('RGB') if image_to_end != None else None, - frame_num=(video_length // 4)* 4 + 1, - max_area=MAX_AREA_CONFIGS[resolution], - shift=flow_shift, - sampling_steps=num_inference_steps, - guide_scale=guidance_scale, - n_prompt=negative_prompt, - seed=seed, - offload_model=False, - callback=callback, - enable_RIFLEx = enable_RIFLEx, - VAE_tile_size = VAE_tile_size, - joint_pass = joint_pass, - slg_layers = slg_layers, - slg_start = slg_start/100, - slg_end = slg_end/100, - ) - - else: - samples = wan_model.generate( - prompt, - frame_num=(video_length // 4)* 4 + 1, - size=(width, height), - shift=flow_shift, - sampling_steps=num_inference_steps, - guide_scale=guidance_scale, - n_prompt=negative_prompt, - seed=seed, - offload_model=False, - callback=callback, - enable_RIFLEx = enable_RIFLEx, - VAE_tile_size = VAE_tile_size, - joint_pass = joint_pass, - slg_layers = slg_layers, - slg_start = slg_start/100, - slg_end = slg_end/100, - ) - except Exception as e: - gen_in_progress = False - if temp_filename!= None and os.path.isfile(temp_filename): - os.remove(temp_filename) - offload.last_offload_obj.unload_all() - offload.unload_loras_from_model(trans) - # if compile: - # cache_size = torch._dynamo.config.cache_size_limit - # torch.compiler.reset() - # torch._dynamo.config.cache_size_limit = cache_size - - gc.collect() - torch.cuda.empty_cache() - s = str(e) - keyword_list = ["vram", "VRAM", "memory","allocat"] - VRAM_crash= False - if any( keyword in s for keyword in keyword_list): + gc.collect() + torch.cuda.empty_cache() + s = str(e) + keyword_list = ["vram", "VRAM", "memory","allocat"] + VRAM_crash= False + if any( keyword in s for keyword in keyword_list): + VRAM_crash = True + else: + stack = traceback.extract_stack(f=None, limit=5) + for frame in stack: + if any( keyword in frame.name for keyword in keyword_list): VRAM_crash = True - else: - stack = traceback.extract_stack(f=None, limit=5) - for frame in stack: - if any( keyword in frame.name for keyword in keyword_list): - VRAM_crash = True - break - state["prompt"] = "" - if VRAM_crash: - raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.") - else: - raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'") + break + state["prompt"] = "" + if VRAM_crash: + raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.") + else: + raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'") + finally: + with tracker_lock: + if task_id in progress_tracker: + del progress_tracker[task_id] - if trans.enable_teacache: - print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) - trans.previous_residual_uncond = None - trans.previous_residual_cond = None + if trans.enable_teacache: + print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) + trans.previous_residual_uncond = None + trans.previous_residual_cond = None - if samples != None: - samples = samples.to("cpu") - offload.last_offload_obj.unload_all() - gc.collect() - torch.cuda.empty_cache() + if samples != None: + samples = samples.to("cpu") + offload.last_offload_obj.unload_all() + gc.collect() + torch.cuda.empty_cache() - if samples == None: - end_time = time.time() - abort = True - state["prompt"] = "" - yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s" - else: - sample = samples.cpu() - # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c") + if samples == None: + end_time = time.time() + abort = True + state["prompt"] = "" + else: + sample = samples.cpu() + # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c") - time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") - if os.name == 'nt': - file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" - else: - file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" - video_path = os.path.join(save_path, file_name) - cache_video( - tensor=sample[None], - save_file=video_path, - fps=16, - nrow=1, - normalize=True, - value_range=(-1, 1)) - - configs = { - 'prompt': prompt, - 'negative_prompt': negative_prompt, - 'resolution': resolution, - 'video_length': video_length, - 'seed': seed, - 'num_inference_steps': num_inference_steps, - } + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") + if os.name == 'nt': + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" + else: + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" + video_path = os.path.join(save_path, file_name) + cache_video( + tensor=sample[None], + save_file=video_path, + fps=16, + nrow=1, + normalize=True, + value_range=(-1, 1)) + + configs = { + 'prompt': prompt, + 'negative_prompt': negative_prompt, + 'resolution': resolution, + 'video_length': video_length, + 'seed': seed, + 'num_inference_steps': num_inference_steps, + } - metadata_choice = server_config["metadata_choice"] - if metadata_choice == "json": - with open(video_path.replace('.mp4', '.json'), 'w') as f: - json.dump(configs, f, indent=4) - elif metadata_choice == "metadata": - from mutagen.mp4 import MP4 - file = MP4(video_path) - file.tags['©cmt'] = [json.dumps(configs)] - file.save() + metadata_choice = server_config["metadata_choice"] + if metadata_choice == "json": + with open(video_path.replace('.mp4', '.json'), 'w') as f: + json.dump(configs, f, indent=4) + elif metadata_choice == "metadata": + from mutagen.mp4 import MP4 + file = MP4(video_path) + file.tags['©cmt'] = [json.dumps(configs)] + file.save() - print(f"New video saved to Path: "+video_path) - file_list.append(video_path) - if video_no < total_video: - yield status - else: - end_time = time.time() - state["prompt"] = "" - yield f"Total Generation Time: {end_time-start_time:.1f}s" - seed += 1 - repeat_no += 1 + print(f"New video saved to Path: "+video_path) + file_list.append(video_path) + seed += 1 if temp_filename!= None and os.path.isfile(temp_filename): os.remove(temp_filename) gen_in_progress = False offload.unload_loras_from_model(trans) - def get_new_preset_msg(advanced = True): if advanced: return "Enter here a Name for a Lora Preset or Choose one in the List" @@ -1952,16 +1995,40 @@ def generate_video_tab(image2video=False): show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) with gr.Column(): - gen_status = gr.Text(label="Status", interactive= False) output = gr.Gallery( label="Generated videos", show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) generate_btn = gr.Button("Generate") - onemore_btn = gr.Button("One More Please !", visible= False) - abort_btn = gr.Button("Abort") - gen_info = gr.Text(label="Current prompt", visible= False , interactive= False) - - + queue_df = gr.DataFrame( + headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""], + datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"], + interactive=False, + col_count=(9, "fixed"), + wrap=True, + value=update_queue_data, + every=1, + elem_id="queue_df" + ) + def handle_selection(evt: gr.SelectData): + cell_value = evt.value + selected_index = evt.index + if cell_value == "↑": + return move_up([selected_index]) + elif cell_value == "↓": + return move_down([selected_index]) + elif cell_value == "✖": + return remove_task([selected_index]) + return queue_df + selected_indices = gr.State([]) + queue_df.select( + fn=handle_selection, + outputs=selected_indices + ) + queue_df.change( + fn=refresh_gallery, + inputs=[state], + outputs=[output] + ) save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, @@ -1978,48 +2045,43 @@ def generate_video_tab(image2video=False): refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) - gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] ) - abort_btn.click(abort_generation,state,abort_btn ) output.select(select_video, state, None ) - onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state]) - generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn] - ).then( - fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt] - ).then( - fn=generate_video, - inputs=[ - prompt, - negative_prompt, - resolution, - video_length, - seed, - num_inference_steps, - guidance_scale, - flow_shift, - embedded_guidance_scale, - repeat_generation, - multi_images_gen_type, - tea_cache_setting, - tea_cache_start_step_perc, - loras_choices, - loras_mult_choices, - image_to_continue, - image_to_end, - video_to_continue, - max_frames, - RIFLEx_setting, - slg_switch, - slg_layers, - slg_start_perc, - slg_end_perc, - state, - gr.State(image2video) - ], - outputs= [gen_status] - ).then( - finalize_gallery, - [state], - [output , abort_btn, generate_btn, onemore_btn, gen_info] + original_inputs = [ + prompt, + negative_prompt, + resolution, + video_length, + seed, + num_inference_steps, + guidance_scale, + flow_shift, + embedded_guidance_scale, + repeat_generation, + multi_images_gen_type, + tea_cache_setting, + tea_cache_start_step_perc, + loras_choices, + loras_mult_choices, + image_to_continue, + image_to_end, + video_to_continue, + max_frames, + RIFLEx_setting, + slg_switch, + slg_layers, + slg_start_perc, + slg_end_perc, + state, + gr.State(image2video) + ] + + #generate_btn.click( + # fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt] + #).then( + generate_btn.click( + fn=process_prompt_and_add_tasks, + inputs=original_inputs, + outputs=queue_df ) return loras_choices, lset_name, header, state @@ -2226,6 +2288,59 @@ def create_demo(): margin: 0 20px; white-space: nowrap; } + .queue-item { + border: 1px solid #ccc; + padding: 10px; + margin: 5px 0; + border-radius: 5px; + } + .current { + background: #f8f9fa; + border-left: 4px solid #007bff; + } + .task-header { + display: flex; + justify-content: space-between; + margin-bottom: 5px; + } + .progress-container { + height: 10px; + background: #e9ecef; + border-radius: 5px; + overflow: hidden; + } + .progress-bar { + height: 100%; + background: #007bff; + transition: width 0.3s ease; + } + .task-details { + display: flex; + justify-content: space-between; + font-size: 0.9em; + color: #6c757d; + margin-top: 5px; + } + .task-prompt { + font-size: 0.8em; + color: #868e96; + margin-top: 5px; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + #queue_df td:nth-child(7), + #queue_df td:nth-child(8), + #queue_df td:nth-child(9) { + cursor: pointer; + text-align: center; + font-weight: bold; + } + #queue_df td:nth-child(7):hover, + #queue_df td:nth-child(8):hover, + #queue_df td:nth-child(9):hover { + background-color: #e0e0e0; + } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo: gr.Markdown("

Wan 2.1GP v3.0 by DeepBeepMeep (Updates)

") @@ -2260,6 +2375,7 @@ def create_demo(): return demo if __name__ == "__main__": + threading.Thread(target=runner, daemon=True).start() os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" server_port = int(args.server_port) if os.name == "nt":