fix broken queue states, avoid unnecessary reloading

This commit is contained in:
Chris Malone 2025-03-25 03:07:35 +11:00
parent b5eca59a71
commit 3cd0dbf4dd

View File

@ -37,6 +37,7 @@ task_id = 0
progress_tracker = {} progress_tracker = {}
tracker_lock = threading.Lock() tracker_lock = threading.Lock()
file_list = [] file_list = []
last_model_type = None
def runner(): def runner():
global current_task_id global current_task_id
@ -47,22 +48,23 @@ def runner():
with tracker_lock: with tracker_lock:
progress = progress_tracker.get(task_id, {}) progress = progress_tracker.get(task_id, {})
if item['state'] != "Queued" and item['state'] != "Finished": if item['state'] == "Processing":
current_step = progress.get('current_step', 0) current_step = progress.get('current_step', 0)
total_steps = progress.get('total_steps', 0) total_steps = progress.get('total_steps', 0)
elapsed = time.time() - progress.get('start_time', time.time()) elapsed = time.time() - progress.get('start_time', time.time())
status = progress.get('status', "") status = progress.get('status', "")
state = progress.get("state") repeats = progress.get("repeats")
item.update({ item.update({
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%", 'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
'steps': f"{current_step}/{total_steps}", 'steps': f"{current_step}/{total_steps}",
'time': f"{elapsed:.1f}s", 'time': f"{elapsed:.1f}s",
'state': f"{state}", 'repeats': f"{repeats}",
'status': f"{status}" 'status': f"{status}"
}) })
if not any(item['state'] == "Processing" for item in queue): if not any(item['state'] == "Processing" for item in queue):
for item in queue: for item in queue:
if item['state'] == "Queued": if item['state'] == "Queued":
item['status'] = "Processing"
item['state'] = "Processing" item['state'] = "Processing"
current_task_id = item['id'] current_task_id = item['id']
threading.Thread(target=process_task, args=(item,)).start() threading.Thread(target=process_task, args=(item,)).start()
@ -160,7 +162,8 @@ def add_video_task(*params):
"id": current_task_id, "id": current_task_id,
"params": (current_task_id,) + params, "params": (current_task_id,) + params,
"state": "Queued", "state": "Queued",
"status": "0/0", "status": "Queued",
"repeats": "0/0",
"progress": "0.0%", "progress": "0.0%",
"steps": f"0/{params[5]}", "steps": f"0/{params[5]}",
"time": "--", "time": "--",
@ -212,8 +215,8 @@ def update_queue_data():
for item in queue: for item in queue:
data.append([ data.append([
str(item['id']), str(item['id']),
item['state'],
item['status'], item['status'],
item['repeats'],
item.get('progress', "0.0%"), item.get('progress', "0.0%"),
item.get('steps', ''), item.get('steps', ''),
item.get('time', '--'), item.get('time', '--'),
@ -1013,8 +1016,8 @@ def build_callback(state, pipe, num_inference_steps, status):
'total_steps': num_inference_steps, 'total_steps': num_inference_steps,
'start_time': start_time, 'start_time': start_time,
'last_update': time.time(), 'last_update': time.time(),
'status': status, 'repeats': status,
'state': phase 'status': phase
} }
return update_progress return update_progress
@ -1078,20 +1081,21 @@ def generate_video(
progress=gr.Progress() #track_tqdm= True progress=gr.Progress() #track_tqdm= True
): ):
global wan_model, offloadobj global wan_model, offloadobj, last_model_type
reload_needed = state.get("_reload_needed", False) reload_needed = state.get("_reload_needed", False)
file_model_needed = model_needed(image2video) file_model_needed = model_needed(image2video)
if(server_config.get("reload_model",1) == 2): with lock:
if file_model_needed != model_filename or reload_needed: queue_not_empty = len(queue) > 0
del wan_model if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
if offloadobj is not None: del wan_model
offloadobj.release() if offloadobj is not None:
del offloadobj offloadobj.release()
gc.collect() del offloadobj
print(f"Loading model {get_model_name(file_model_needed)}...") gc.collect()
wan_model, offloadobj, trans = load_models(image2video) print(f"Loading model {get_model_name(file_model_needed)}...")
print(f"Model loaded") wan_model, offloadobj, trans = load_models(image2video)
state["_reload_needed"] = False print(f"Model loaded")
state["_reload_needed"] = False
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@ -1251,13 +1255,6 @@ def generate_video(
global save_path global save_path
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
abort = False abort = False
with tracker_lock:
progress_tracker[task_id] = {
'current_step': 0,
'total_steps': num_inference_steps,
'start_time': time.time(),
'last_update': time.time()
}
if trans.enable_teacache: if trans.enable_teacache:
trans.teacache_counter = 0 trans.teacache_counter = 0
trans.num_steps = num_inference_steps trans.num_steps = num_inference_steps
@ -1268,8 +1265,12 @@ def generate_video(
status = f"{video_no}/{repeat_generation}" status = f"{video_no}/{repeat_generation}"
with tracker_lock: with tracker_lock:
if task_id in progress_tracker: if task_id in progress_tracker:
progress_tracker[task_id]['state'] = "Encoding Prompt" progress_tracker[task_id]['status'] = "Encoding Prompt"
progress_tracker[task_id]['status'] = status progress_tracker[task_id]['repeats'] = status
progress_tracker[task_id]['current_step'] = 0
progress_tracker[task_id]['total_steps'] = num_inference_steps
progress_tracker[task_id]['start_time'] = time.time()
progress_tracker[task_id]['last_update'] = time.time()
callback = build_callback(state, trans, num_inference_steps, status) callback = build_callback(state, trans, num_inference_steps, status)
offload.shared_state["callback"] = callback offload.shared_state["callback"] = callback
gc.collect() gc.collect()
@ -1279,7 +1280,7 @@ def generate_video(
try: try:
with tracker_lock: with tracker_lock:
if task_id in progress_tracker: if task_id in progress_tracker:
progress_tracker[task_id]['status'] = video_no progress_tracker[task_id]['repeats'] = video_no
video_no += 1 video_no += 1
if image2video: if image2video:
samples = wan_model.generate( samples = wan_model.generate(
@ -1326,8 +1327,8 @@ def generate_video(
gen_in_progress = False gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename): if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename) os.remove(temp_filename)
offload.last_offload_obj.unload_all() if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans) if(trans): offload.unload_loras_from_model(trans)
# if compile: # if compile:
# cache_size = torch._dynamo.config.cache_size_limit # cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset() # torch.compiler.reset()
@ -1411,6 +1412,7 @@ def generate_video(
print(f"New video saved to Path: "+video_path) print(f"New video saved to Path: "+video_path)
file_list.append(video_path) file_list.append(video_path)
seed += 1 seed += 1
last_model_type = image2video
if temp_filename!= None and os.path.isfile(temp_filename): if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename) os.remove(temp_filename)
@ -2291,15 +2293,19 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
use_image2video = new_i2v use_image2video = new_i2v
if(server_config.get("reload_model",2) == 1): if(server_config.get("reload_model",2) == 1):
global wan_model, offloadobj with lock:
if wan_model is not None: queue_empty = len(queue) == 0
if offloadobj is not None: if queue_empty:
offloadobj.release() global wan_model, offloadobj
offloadobj = None if wan_model is not None:
wan_model = None if offloadobj is not None:
gc.collect() offloadobj.release()
torch.cuda.empty_cache() offloadobj = None
wan_model, offloadobj, trans = load_models(use_image2video) wan_model = None
gc.collect()
torch.cuda.empty_cache()
wan_model, offloadobj, trans = load_models(use_image2video)
del trans
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode) t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode) i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)