restored original detailed status updates, fixed repeats

Chris Malone 2025-03-24 21:00:30 +11:00
parent bb3f99012f
commit 646d9ace13

@@ -1,7 +1,6 @@
 import os
 import time
 import threading
-from typing import Dict, Any
 import argparse
 from mmgp import offload, safetensors2, profile_type
 try:
@@ -23,7 +22,6 @@ import traceback
 import math
 import asyncio
 from wan.utils import prompt_parser
-from typing import List
 
 PROMPT_VARS_MAX = 10
 target_mmgp_version = "3.3.0"
@ -49,19 +47,23 @@ def runner():
with tracker_lock: with tracker_lock:
progress = progress_tracker.get(task_id, {}) progress = progress_tracker.get(task_id, {})
if item['status'] == "Processing": if item['state'] != "Queued" and item['state'] != "Finished":
current_step = progress.get('current_step', 0) current_step = progress.get('current_step', 0)
total_steps = progress.get('total_steps', 0) total_steps = progress.get('total_steps', 0)
elapsed = time.time() - progress.get('start_time', time.time()) elapsed = time.time() - progress.get('start_time', time.time())
status = progress.get('status', "")
state = progress.get("state")
item.update({ item.update({
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%", 'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
'steps': f"{current_step}/{total_steps}", 'steps': f"{current_step}/{total_steps}",
'time': f"{elapsed:.1f}s" 'time': f"{elapsed:.1f}s",
'state': f"{state}",
'status': f"{status}"
}) })
if not any(item['status'] == "Processing" for item in queue): if not any(item['state'] == "Processing" for item in queue):
for item in queue: for item in queue:
if item['status'] == "Queued": if item['state'] == "Queued":
item['status'] = "Processing" item['state'] = "Processing"
current_task_id = item['id'] current_task_id = item['id']
threading.Thread(target=process_task, args=(item,)).start() threading.Thread(target=process_task, args=(item,)).start()
break break
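Read on its own, the new scheduling branch is easy to miss: the runner now keys everything off 'state' and promotes at most one queued item at a time, while 'status' is left free to carry the repeat counter. A minimal standalone sketch of that behaviour (the helper name and sample data are illustrative, not from the file):

# Illustrative sketch of the queue promotion logic above, not the app's code.
# 'state' moves Queued -> Processing -> Finished; 'status' holds "done/repeats".
def start_next(queue):
    if any(item['state'] == "Processing" for item in queue):
        return None                      # only one task runs at a time
    for item in queue:
        if item['state'] == "Queued":
            item['state'] = "Processing"
            return item
    return None

queue = [{'id': 1, 'state': "Queued", 'status': "0/2"},
         {'id': 2, 'state': "Queued", 'status': "0/1"}]
print(start_next(queue)['id'])   # -> 1; item 2 stays queued until 1 finishes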
@@ -157,7 +159,8 @@ def add_video_task(*params):
     queue.append({
         "id": current_task_id,
         "params": (current_task_id,) + params,
-        "status": "Queued",
+        "state": "Queued",
+        "status": "0/0",
         "progress": "0.0%",
         "steps": f"0/{params[5]}",
         "time": "--",
@@ -209,6 +212,7 @@ def update_queue_data():
     for item in queue:
         data.append([
             str(item['id']),
+            item['state'],
             item['status'],
             item.get('progress', "0.0%"),
             item.get('steps', ''),
@ -985,29 +989,38 @@ def save_video(final_frames, output_path, fps=24):
final_frames = (final_frames * 255).astype(np.uint8) final_frames = (final_frames * 255).astype(np.uint8)
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None) ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
def build_callback(task_id, total_steps): def build_callback(state, pipe, num_inference_steps, status):
start_time = time.time() start_time = time.time()
def update_progress(step, _): def update_progress(step_idx, latents, read_state = False):
with tracker_lock: with tracker_lock:
step_idx += 1
if state.get("abort", False):
# pipe._interrupt = True
phase = "Aborting"
elif step_idx == num_inference_steps:
phase = "VAE Decoding"
else:
phase = "Denoising"
elapsed = time.time() - start_time elapsed = time.time() - start_time
progress_tracker[task_id] = { progress_tracker[task_id] = {
'current_step': step + 1, 'current_step': step_idx,
'total_steps': total_steps, 'total_steps': num_inference_steps,
'start_time': start_time, 'start_time': start_time,
'last_update': time.time() 'last_update': time.time(),
'status': status,
'state': phase
} }
return update_progress return update_progress
def refresh_gallery(state): def refresh_gallery(state):
file_list = state.get("file_list", None) file_list = state.get("file_list", None)
return file_list return file_list
def finalize_gallery(state): def finalize_gallery(state):
choice = 0 choice = 0
if "in_progress" in state: if "in_progress" in state:
del state["in_progress"] del state["in_progress"]
choice = state.get("selected",0) choice = state.get("selected",0)
state["extra_orders"] = 0
time.sleep(0.2) time.sleep(0.2)
global gen_in_progress global gen_in_progress
gen_in_progress = False gen_in_progress = False
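The rewritten callback is what feeds the 'state'/'status' fields the runner reads. In the app it is handed to the sampler via offload.shared_state["callback"] and the callback= argument of wan_model.generate; the snippet below only simulates that driver loop to show how the phase labels come out. The loop, the sample values, and the module-level task_id (which the real file takes from the enclosing scope) are illustrative assumptions:

# Standalone simulation of the callback defined in the hunk above; the for
# loop stands in for the denoising loop inside wan_model.generate.
import threading, time

tracker_lock = threading.Lock()
progress_tracker = {}
task_id = 1   # assumption: provided by the surrounding generate_video scope in the app

def build_callback(state, pipe, num_inference_steps, status):
    start_time = time.time()
    def update_progress(step_idx, latents, read_state=False):
        with tracker_lock:
            step_idx += 1                         # callback receives a 0-based step index
            if state.get("abort", False):
                phase = "Aborting"
            elif step_idx == num_inference_steps:
                phase = "VAE Decoding"            # last step flips to the decode phase
            else:
                phase = "Denoising"
            progress_tracker[task_id] = {
                'current_step': step_idx,
                'total_steps': num_inference_steps,
                'start_time': start_time,
                'last_update': time.time(),
                'status': status,
                'state': phase,
            }
    return update_progress

callback = build_callback({"abort": False}, None, 3, "1/1")
for step in range(3):            # stand-in for the sampler's per-step callback calls
    callback(step, None)
print(progress_tracker[task_id]['state'])   # -> "VAE Decoding"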
@ -1067,6 +1080,7 @@ def generate_video(
offloadobj.release() offloadobj.release()
offloadobj = None offloadobj = None
wan_model = None wan_model = None
gc.collect()
print(f"Loading model {get_model_name(file_model_needed)}...") print(f"Loading model {get_model_name(file_model_needed)}...")
wan_model, offloadobj, trans = load_models(image2video) wan_model, offloadobj, trans = load_models(image2video)
print(f"Model loaded") print(f"Model loaded")
@@ -1243,143 +1257,153 @@ def generate_video(
     trans.teacache_skipped_steps = 0
     trans.previous_residual_uncond = None
     trans.previous_residual_cond = None
-    callback = build_callback(task_id, num_inference_steps)
-    offload.shared_state["callback"] = callback
+    video_no = 0
+    status = f"{video_no}/{repeat_generation}"
+    with tracker_lock:
+        if task_id in progress_tracker:
+            progress_tracker[task_id]['state'] = "Encoding Prompt"
+            progress_tracker[task_id]['status'] = status
+    callback = build_callback(state, trans, num_inference_steps, status)
+    offload.shared_state["callback"] = callback
     gc.collect()
     torch.cuda.empty_cache()
     wan_model._interrupt = False
-    state["progress_status"] = "Starting"
-    try:
-        if image2video:
-            samples = wan_model.generate(
-                prompt,
-                image_to_continue.convert('RGB'),
-                image_to_end.convert('RGB') if image_to_end != None else None,
-                frame_num=(video_length // 4)* 4 + 1,
-                max_area=MAX_AREA_CONFIGS[resolution],
-                shift=flow_shift,
-                sampling_steps=num_inference_steps,
-                guide_scale=guidance_scale,
-                n_prompt=negative_prompt,
-                seed=seed,
-                offload_model=False,
-                callback=callback,
-                enable_RIFLEx = enable_RIFLEx,
-                VAE_tile_size = VAE_tile_size,
-                joint_pass = joint_pass,
-                slg_layers = slg_layers,
-                slg_start = slg_start/100,
-                slg_end = slg_end/100,
-            )
-        else:
-            samples = wan_model.generate(
-                prompt,
-                frame_num=(video_length // 4)* 4 + 1,
-                size=(width, height),
-                shift=flow_shift,
-                sampling_steps=num_inference_steps,
-                guide_scale=guidance_scale,
-                n_prompt=negative_prompt,
-                seed=seed,
-                offload_model=False,
-                callback=callback,
-                enable_RIFLEx = enable_RIFLEx,
-                VAE_tile_size = VAE_tile_size,
-                joint_pass = joint_pass,
-                slg_layers = slg_layers,
-                slg_start = slg_start/100,
-                slg_end = slg_end/100,
-            )
-    except Exception as e:
-        gen_in_progress = False
-        if temp_filename!= None and os.path.isfile(temp_filename):
-            os.remove(temp_filename)
-        offload.last_offload_obj.unload_all()
-        offload.unload_loras_from_model(trans)
-        # if compile:
-        #     cache_size = torch._dynamo.config.cache_size_limit
-        #     torch.compiler.reset()
-        #     torch._dynamo.config.cache_size_limit = cache_size
-        gc.collect()
-        torch.cuda.empty_cache()
-        s = str(e)
-        keyword_list = ["vram", "VRAM", "memory","allocat"]
-        VRAM_crash= False
-        if any( keyword in s for keyword in keyword_list):
-            VRAM_crash = True
-        else:
-            stack = traceback.extract_stack(f=None, limit=5)
-            for frame in stack:
-                if any( keyword in frame.name for keyword in keyword_list):
-                    VRAM_crash = True
-                    break
-        state["prompt"] = ""
-        if VRAM_crash:
-            raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
-        else:
-            raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
-    finally:
-        with tracker_lock:
-            if task_id in progress_tracker:
-                del progress_tracker[task_id]
-    if trans.enable_teacache:
-        print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
-        trans.previous_residual_uncond = None
-        trans.previous_residual_cond = None
-    if samples != None:
-        samples = samples.to("cpu")
-    offload.last_offload_obj.unload_all()
-    gc.collect()
-    torch.cuda.empty_cache()
-    if samples == None:
-        end_time = time.time()
-        abort = True
-        state["prompt"] = ""
-    else:
-        sample = samples.cpu()
-        # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
-        time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
-        if os.name == 'nt':
-            file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
-        else:
-            file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
-        video_path = os.path.join(save_path, file_name)
-        cache_video(
-            tensor=sample[None],
-            save_file=video_path,
-            fps=16,
-            nrow=1,
-            normalize=True,
-            value_range=(-1, 1))
-        configs = {
-            'prompt': prompt,
-            'negative_prompt': negative_prompt,
-            'resolution': resolution,
-            'video_length': video_length,
-            'seed': seed,
-            'num_inference_steps': num_inference_steps,
-        }
-        metadata_choice = server_config.get("metadata_choice","metadata")
-        if metadata_choice == "json":
-            with open(video_path.replace('.mp4', '.json'), 'w') as f:
-                json.dump(configs, f, indent=4)
-        elif metadata_choice == "metadata":
-            from mutagen.mp4 import MP4
-            file = MP4(video_path)
-            file.tags['©cmt'] = [json.dumps(configs)]
-            file.save()
-        print(f"New video saved to Path: "+video_path)
-        file_list.append(video_path)
-        seed += 1
+    for i in range(repeat_generation):
+        try:
+            with tracker_lock:
+                if task_id in progress_tracker:
+                    progress_tracker[task_id]['status'] = video_no
+            video_no += 1
+            if image2video:
+                samples = wan_model.generate(
+                    prompt,
+                    image_to_continue.convert('RGB'),
+                    image_to_end.convert('RGB') if image_to_end != None else None,
+                    frame_num=(video_length // 4)* 4 + 1,
+                    max_area=MAX_AREA_CONFIGS[resolution],
+                    shift=flow_shift,
+                    sampling_steps=num_inference_steps,
+                    guide_scale=guidance_scale,
+                    n_prompt=negative_prompt,
+                    seed=seed,
+                    offload_model=False,
+                    callback=callback,
+                    enable_RIFLEx = enable_RIFLEx,
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
+                    slg_layers = slg_layers,
+                    slg_start = slg_start/100,
+                    slg_end = slg_end/100,
+                )
+            else:
+                samples = wan_model.generate(
+                    prompt,
+                    frame_num=(video_length // 4)* 4 + 1,
+                    size=(width, height),
+                    shift=flow_shift,
+                    sampling_steps=num_inference_steps,
+                    guide_scale=guidance_scale,
+                    n_prompt=negative_prompt,
+                    seed=seed,
+                    offload_model=False,
+                    callback=callback,
+                    enable_RIFLEx = enable_RIFLEx,
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
+                    slg_layers = slg_layers,
+                    slg_start = slg_start/100,
+                    slg_end = slg_end/100,
+                )
+        except Exception as e:
+            gen_in_progress = False
+            if temp_filename!= None and os.path.isfile(temp_filename):
+                os.remove(temp_filename)
+            offload.last_offload_obj.unload_all()
+            offload.unload_loras_from_model(trans)
+            # if compile:
+            #     cache_size = torch._dynamo.config.cache_size_limit
+            #     torch.compiler.reset()
+            #     torch._dynamo.config.cache_size_limit = cache_size
+            gc.collect()
+            torch.cuda.empty_cache()
+            s = str(e)
+            keyword_list = ["vram", "VRAM", "memory","allocat"]
+            VRAM_crash= False
+            if any( keyword in s for keyword in keyword_list):
+                VRAM_crash = True
+            else:
+                stack = traceback.extract_stack(f=None, limit=5)
+                for frame in stack:
+                    if any( keyword in frame.name for keyword in keyword_list):
+                        VRAM_crash = True
+                        break
+            state["prompt"] = ""
+            if VRAM_crash:
+                raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
+            else:
+                raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
+        finally:
+            with tracker_lock:
+                if task_id in progress_tracker:
+                    del progress_tracker[task_id]
+        if trans.enable_teacache:
+            print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
+            trans.previous_residual_uncond = None
+            trans.previous_residual_cond = None
+        if samples != None:
+            samples = samples.to("cpu")
+        offload.last_offload_obj.unload_all()
+        gc.collect()
+        torch.cuda.empty_cache()
+        if samples == None:
+            end_time = time.time()
+            abort = True
+            state["prompt"] = ""
+        else:
+            sample = samples.cpu()
+            # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
+            time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
+            if os.name == 'nt':
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
+            else:
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
+            video_path = os.path.join(save_path, file_name)
+            cache_video(
+                tensor=sample[None],
+                save_file=video_path,
+                fps=16,
+                nrow=1,
+                normalize=True,
+                value_range=(-1, 1))
+            configs = {
+                'prompt': prompt,
+                'negative_prompt': negative_prompt,
+                'resolution': resolution,
+                'video_length': video_length,
+                'seed': seed,
+                'num_inference_steps': num_inference_steps,
+            }
+            metadata_choice = server_config.get("metadata_choice","metadata")
+            if metadata_choice == "json":
+                with open(video_path.replace('.mp4', '.json'), 'w') as f:
+                    json.dump(configs, f, indent=4)
+            elif metadata_choice == "metadata":
+                from mutagen.mp4 import MP4
+                file = MP4(video_path)
+                file.tags['©cmt'] = [json.dumps(configs)]
+                file.save()
+            print(f"New video saved to Path: "+video_path)
+            file_list.append(video_path)
+            seed += 1
     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
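Condensed, the second half of this hunk wraps the existing generate-and-save logic in a per-repeat loop and bumps the tracker's repeat counter before each pass. A stripped-down sketch of that flow (the helper names are illustrative stand-ins for wan_model.generate and cache_video, not the app's API):

# Illustrative sketch of the new repeat loop, not the app's code.
def run_repeats(repeat_generation, generate_once, save_video_file, tracker):
    video_no = 0
    for i in range(repeat_generation):
        tracker['status'] = video_no          # repeat counter surfaced in the queue table
        video_no += 1
        samples = generate_once()             # wan_model.generate(...) in the app
        if samples is None:                    # generation was aborted
            break
        save_video_file(samples)               # cache_video(...) + metadata in the app

tracker = {}
run_repeats(2, lambda: "frames", lambda s: print("saved", s), tracker)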
@@ -2006,10 +2030,10 @@ def generate_video_tab(image2video=False):
                     , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
         generate_btn = gr.Button("Generate")
         queue_df = gr.DataFrame(
-            headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
-            datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
+            headers=["ID", "Status", "Repeats", "Progress", "Steps", "Time", "Prompt", "", "", ""],
+            datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str", "str"],
             interactive=False,
-            col_count=(9, "fixed"),
+            col_count=(10, "fixed"),
             wrap=True,
             value=update_queue_data,
             every=1,
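The extra "Repeats" column only works because update_queue_data (edited above) emits its rows in the same order as these headers, with the item's 'state' landing under "Status" and its 'status' counter under "Repeats". A simplified sketch of that row layout (the prompt truncation and the three trailing action placeholders are assumptions, not taken from the commit):

# Illustrative row builder matching the ten headers above; the real
# update_queue_data in the app takes no argument and reads the global queue.
def build_rows(queue):
    data = []
    for item in queue:
        data.append([
            str(item['id']),
            item['state'],                 # shown under "Status"
            item['status'],                # shown under "Repeats"
            item.get('progress', "0.0%"),
            item.get('steps', ''),
            item.get('time', '--'),
            item.get('prompt', '')[:50],   # truncation length is illustrative
            "↑", "↓", "✖",                 # placeholders for the three action columns
        ])
    return data

print(build_rows([{'id': 1, 'state': "Queued", 'status': "0/2"}]))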