From 646d9ace13db54def4e43a2dbf08b81083388c52 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Mon, 24 Mar 2025 21:00:30 +1100
Subject: [PATCH] restored original detailed status updates, fixed repeats

---
 gradio_server.py | 315 +++++++++++++++++++++++++----------------------
 1 file changed, 171 insertions(+), 144 deletions(-)

diff --git a/gradio_server.py b/gradio_server.py
index d7037e4..2d551cb 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -1,7 +1,6 @@
 import os
 import time
 import threading
-from typing import Dict, Any
 import argparse
 from mmgp import offload, safetensors2, profile_type
 try:
@@ -23,7 +22,6 @@ import traceback
 import math
 import asyncio
 from wan.utils import prompt_parser
-from typing import List
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.3.0"
@@ -49,19 +47,23 @@ def runner():
            with tracker_lock:
                progress = progress_tracker.get(task_id, {})

-            if item['status'] == "Processing":
+            if item['state'] != "Queued" and item['state'] != "Finished":
                current_step = progress.get('current_step', 0)
                total_steps = progress.get('total_steps', 0)
                elapsed = time.time() - progress.get('start_time', time.time())
+                status = progress.get('status', "")
+                state = progress.get("state", "")
                item.update({
                    'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
                    'steps': f"{current_step}/{total_steps}",
-                    'time': f"{elapsed:.1f}s"
+                    'time': f"{elapsed:.1f}s",
+                    'state': f"{state}",
+                    'status': f"{status}"
                })
-        if not any(item['status'] == "Processing" for item in queue):
+        if not any(item['state'] == "Processing" for item in queue):
            for item in queue:
-                if item['status'] == "Queued":
-                    item['status'] = "Processing"
+                if item['state'] == "Queued":
+                    item['state'] = "Processing"
                    current_task_id = item['id']
                    threading.Thread(target=process_task, args=(item,)).start()
                    break
@@ -157,7 +159,8 @@ def add_video_task(*params):
    queue.append({
        "id": current_task_id,
        "params": (current_task_id,) + params,
-        "status": "Queued",
+        "state": "Queued",
+        "status": "0/0",
        "progress": "0.0%",
        "steps": f"0/{params[5]}",
        "time": "--",
@@ -209,6 +212,7 @@ def update_queue_data():
    for item in queue:
        data.append([
            str(item['id']),
+            item['state'],
            item['status'],
            item.get('progress', "0.0%"),
            item.get('steps', ''),
@@ -985,29 +989,39 @@ def save_video(final_frames, output_path, fps=24):
        final_frames = (final_frames * 255).astype(np.uint8)
    ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)

-def build_callback(task_id, total_steps):
+def build_callback(task_id, state, pipe, num_inference_steps, status):
    start_time = time.time()
-    def update_progress(step, _):
+    def update_progress(step_idx, latents, read_state=False):
        with tracker_lock:
+            step_idx += 1
+            if state.get("abort", False):
+                # pipe._interrupt = True
+                phase = "Aborting"
+            elif step_idx == num_inference_steps:
+                phase = "VAE Decoding"
+            else:
+                phase = "Denoising"
            elapsed = time.time() - start_time
            progress_tracker[task_id] = {
-                'current_step': step + 1,
-                'total_steps': total_steps,
+                'current_step': step_idx,
+                'total_steps': num_inference_steps,
                'start_time': start_time,
-                'last_update': time.time()
+                'last_update': time.time(),
+                'status': progress_tracker.get(task_id, {}).get('status', status),  # keep the live repeats counter if one was already set
+                'state': phase
            }
    return update_progress

def refresh_gallery(state):
    file_list = state.get("file_list", None)
    return file_list
+
def finalize_gallery(state):
    choice = 0
    if "in_progress" in state:
        del state["in_progress"]
        choice = state.get("selected",0)
-
-    state["extra_orders"] = 0
    time.sleep(0.2)
    global gen_in_progress
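+    # the queued generation is finished: clear the global busy flag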
    gen_in_progress = False
@@ -1067,6 +1081,7 @@ def generate_video(
            offloadobj.release()
            offloadobj = None
            wan_model = None
+            gc.collect()
            print(f"Loading model {get_model_name(file_model_needed)}...")
            wan_model, offloadobj, trans = load_models(image2video)
            print(f"Model loaded")
@@ -1243,143 +1258,154 @@ def generate_video(
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None
-        callback = build_callback(task_id, num_inference_steps)
-        offload.shared_state["callback"] = callback
+        video_no = 0
+        status = f"{video_no}/{repeat_generation}"
+        with tracker_lock:
+            if task_id in progress_tracker:
+                progress_tracker[task_id]['state'] = "Encoding Prompt"
+                progress_tracker[task_id]['status'] = status
+        callback = build_callback(task_id, state, trans, num_inference_steps, status)
+        offload.shared_state["callback"] = callback
        gc.collect()
        torch.cuda.empty_cache()
        wan_model._interrupt = False
-        state["progress_status"] = "Starting"
-        try:
-            if image2video:
-                samples = wan_model.generate(
-                    prompt,
-                    image_to_continue.convert('RGB'),
-                    image_to_end.convert('RGB') if image_to_end != None else None,
-                    frame_num=(video_length // 4)* 4 + 1,
-                    max_area=MAX_AREA_CONFIGS[resolution],
-                    shift=flow_shift,
-                    sampling_steps=num_inference_steps,
-                    guide_scale=guidance_scale,
-                    n_prompt=negative_prompt,
-                    seed=seed,
-                    offload_model=False,
-                    callback=callback,
-                    enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size,
-                    joint_pass = joint_pass,
-                    slg_layers = slg_layers,
-                    slg_start = slg_start/100,
-                    slg_end = slg_end/100,
-                )
+        for i in range(repeat_generation):
+            try:
+                video_no += 1
+                status = f"{video_no}/{repeat_generation}"
+                with tracker_lock:
+                    progress_tracker.setdefault(task_id, {})['status'] = status
+                if image2video:
+                    samples = wan_model.generate(
+                        prompt,
+                        image_to_continue.convert('RGB'),
+                        image_to_end.convert('RGB') if image_to_end != None else None,
+                        frame_num=(video_length // 4)* 4 + 1,
+                        max_area=MAX_AREA_CONFIGS[resolution],
+                        shift=flow_shift,
+                        sampling_steps=num_inference_steps,
+                        guide_scale=guidance_scale,
+                        n_prompt=negative_prompt,
+                        seed=seed,
+                        offload_model=False,
+                        callback=callback,
+                        enable_RIFLEx = enable_RIFLEx,
+                        VAE_tile_size = VAE_tile_size,
+                        joint_pass = joint_pass,
+                        slg_layers = slg_layers,
+                        slg_start = slg_start/100,
+                        slg_end = slg_end/100,
+                    )

-            else:
-                samples = wan_model.generate(
-                    prompt,
-                    frame_num=(video_length // 4)* 4 + 1,
-                    size=(width, height),
-                    shift=flow_shift,
-                    sampling_steps=num_inference_steps,
-                    guide_scale=guidance_scale,
-                    n_prompt=negative_prompt,
-                    seed=seed,
-                    offload_model=False,
-                    callback=callback,
-                    enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size,
-                    joint_pass = joint_pass,
-                    slg_layers = slg_layers,
-                    slg_start = slg_start/100,
-                    slg_end = slg_end/100,
-                )
-        except Exception as e:
-            gen_in_progress = False
-            if temp_filename!= None and os.path.isfile(temp_filename):
-                os.remove(temp_filename)
+                else:
+                    samples = wan_model.generate(
+                        prompt,
+                        frame_num=(video_length // 4)* 4 + 1,
+                        size=(width, height),
+                        shift=flow_shift,
+                        sampling_steps=num_inference_steps,
+                        guide_scale=guidance_scale,
+                        n_prompt=negative_prompt,
+                        seed=seed,
+                        offload_model=False,
+                        callback=callback,
+                        enable_RIFLEx = enable_RIFLEx,
+                        VAE_tile_size = VAE_tile_size,
+                        joint_pass = joint_pass,
+                        slg_layers = slg_layers,
+                        slg_start = slg_start/100,
+                        slg_end = slg_end/100,
+                    )
+            except Exception as e:
+                gen_in_progress = False
+                if temp_filename!= None and os.path.isfile(temp_filename):
+                    os.remove(temp_filename)
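+                # free as much VRAM as possible before surfacing the error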
+                offload.last_offload_obj.unload_all()
+                offload.unload_loras_from_model(trans)
+                # if compile:
+                #     cache_size = torch._dynamo.config.cache_size_limit
+                #     torch.compiler.reset()
+                #     torch._dynamo.config.cache_size_limit = cache_size
+
+                gc.collect()
+                torch.cuda.empty_cache()
+                s = str(e)
+                keyword_list = ["vram", "VRAM", "memory","allocat"]
+                VRAM_crash = False
+                if any( keyword in s for keyword in keyword_list):
+                    VRAM_crash = True
+                else:
+                    stack = traceback.extract_stack(f=None, limit=5)
+                    for frame in stack:
+                        if any( keyword in frame.name for keyword in keyword_list):
+                            VRAM_crash = True
+                            break
+                state["prompt"] = ""
+                if VRAM_crash:
+                    raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
+                else:
+                    raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
+            finally:
+                with tracker_lock:
+                    if task_id in progress_tracker:
+                        del progress_tracker[task_id]
+
+            if trans.enable_teacache:
+                print(f"Teacache Skipped Steps: {trans.teacache_skipped_steps}/{num_inference_steps}")
+                trans.previous_residual_uncond = None
+                trans.previous_residual_cond = None
+
+            if samples != None:
+                samples = samples.to("cpu")
            offload.last_offload_obj.unload_all()
-            offload.unload_loras_from_model(trans)
-            # if compile:
-            #     cache_size = torch._dynamo.config.cache_size_limit
-            #     torch.compiler.reset()
-            #     torch._dynamo.config.cache_size_limit = cache_size
-
            gc.collect()
            torch.cuda.empty_cache()
-            s = str(e)
-            keyword_list = ["vram", "VRAM", "memory","allocat"]
-            VRAM_crash= False
-            if any( keyword in s for keyword in keyword_list):
-                VRAM_crash = True
+
+            if samples == None:
+                end_time = time.time()
+                abort = True
+                state["prompt"] = ""
            else:
-                stack = traceback.extract_stack(f=None, limit=5)
-                for frame in stack:
-                    if any( keyword in frame.name for keyword in keyword_list):
-                        VRAM_crash = True
-                        break
-            state["prompt"] = ""
-            if VRAM_crash:
-                raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
-            else:
-                raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
'{s}'") - finally: - with tracker_lock: - if task_id in progress_tracker: - del progress_tracker[task_id] + sample = samples.cpu() + # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c") - if trans.enable_teacache: - print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) - trans.previous_residual_uncond = None - trans.previous_residual_cond = None + time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") + if os.name == 'nt': + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" + else: + file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" + video_path = os.path.join(save_path, file_name) + cache_video( + tensor=sample[None], + save_file=video_path, + fps=16, + nrow=1, + normalize=True, + value_range=(-1, 1)) + + configs = { + 'prompt': prompt, + 'negative_prompt': negative_prompt, + 'resolution': resolution, + 'video_length': video_length, + 'seed': seed, + 'num_inference_steps': num_inference_steps, + } - if samples != None: - samples = samples.to("cpu") - offload.last_offload_obj.unload_all() - gc.collect() - torch.cuda.empty_cache() + metadata_choice = server_config.get("metadata_choice","metadata") + if metadata_choice == "json": + with open(video_path.replace('.mp4', '.json'), 'w') as f: + json.dump(configs, f, indent=4) + elif metadata_choice == "metadata": + from mutagen.mp4 import MP4 + file = MP4(video_path) + file.tags['©cmt'] = [json.dumps(configs)] + file.save() - if samples == None: - end_time = time.time() - abort = True - state["prompt"] = "" - else: - sample = samples.cpu() - # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c") - - time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") - if os.name == 'nt': - file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" - else: - file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" - video_path = os.path.join(save_path, file_name) - cache_video( - tensor=sample[None], - save_file=video_path, - fps=16, - nrow=1, - normalize=True, - value_range=(-1, 1)) - - configs = { - 'prompt': prompt, - 'negative_prompt': negative_prompt, - 'resolution': resolution, - 'video_length': video_length, - 'seed': seed, - 'num_inference_steps': num_inference_steps, - } - - metadata_choice = server_config.get("metadata_choice","metadata") - if metadata_choice == "json": - with open(video_path.replace('.mp4', '.json'), 'w') as f: - json.dump(configs, f, indent=4) - elif metadata_choice == "metadata": - from mutagen.mp4 import MP4 - file = MP4(video_path) - file.tags['©cmt'] = [json.dumps(configs)] - file.save() - - print(f"New video saved to Path: "+video_path) - file_list.append(video_path) - seed += 1 + print(f"New video saved to Path: "+video_path) + file_list.append(video_path) + seed += 1 if temp_filename!= None and os.path.isfile(temp_filename): os.remove(temp_filename) @@ -2006,10 +2030,10 @@ def generate_video_tab(image2video=False): , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False) generate_btn = gr.Button("Generate") queue_df = gr.DataFrame( - headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""], - datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"], + headers=["ID", "Status", "Repeats", "Progress", "Steps", "Time", "Prompt", "", "", ""], + datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str", "str"], 
                interactive=False,
-                col_count=(9, "fixed"),
+                col_count=(10, "fixed"),
                wrap=True,
                value=update_queue_data,
                every=1,