Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-05 14:33:15 +00:00
fix broken queue states, avoid unnecessary reloading
commit 3cd0dbf4dd
parent b5eca59a71
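Note: the hunks below rename the per-task bookkeeping fields. The queue item keeps `state` for the lifecycle (Queued / Processing / Finished), its old "n/m" counter moves from `status` to a new `repeats` field, and `status` now carries the human-readable phase; the same renaming is applied inside `progress_tracker`. A minimal sketch of the resulting field layout, reconstructed from the diff (example values are placeholders, and the real dictionaries carry additional keys such as `params`):

# Illustrative field layout implied by this commit, not the full data structures.
queue_item = {
    "id": 1,
    "state": "Queued",            # lifecycle: Queued -> Processing -> Finished
    "status": "Queued",           # human-readable phase, e.g. "Encoding Prompt"
    "repeats": "0/0",             # video_no/repeat_generation counter (was 'status')
    "progress": "0.0%",
    "steps": "0/30",              # current_step/total_steps
    "time": "--",
}

progress_tracker_entry = {        # stored under progress_tracker[task_id]
    "current_step": 0,
    "total_steps": 30,
    "start_time": 0.0,            # time.time() when the repeat started
    "last_update": 0.0,
    "status": "Encoding Prompt",  # was 'state'
    "repeats": "1/3",             # was 'status'
}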
@@ -37,6 +37,7 @@ task_id = 0
 progress_tracker = {}
 tracker_lock = threading.Lock()
 file_list = []
+last_model_type = None
 
 def runner():
     global current_task_id
@@ -47,22 +48,23 @@ def runner():
             with tracker_lock:
                 progress = progress_tracker.get(task_id, {})
 
-            if item['state'] != "Queued" and item['state'] != "Finished":
+            if item['state'] == "Processing":
                 current_step = progress.get('current_step', 0)
                 total_steps = progress.get('total_steps', 0)
                 elapsed = time.time() - progress.get('start_time', time.time())
                 status = progress.get('status', "")
-                state = progress.get("state")
+                repeats = progress.get("repeats")
                 item.update({
                     'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
                     'steps': f"{current_step}/{total_steps}",
                     'time': f"{elapsed:.1f}s",
-                    'state': f"{state}",
+                    'repeats': f"{repeats}",
                     'status': f"{status}"
                 })
         if not any(item['state'] == "Processing" for item in queue):
             for item in queue:
                 if item['state'] == "Queued":
+                    item['status'] = "Processing"
                     item['state'] = "Processing"
                     current_task_id = item['id']
                     threading.Thread(target=process_task, args=(item,)).start()
@@ -160,7 +162,8 @@ def add_video_task(*params):
         "id": current_task_id,
         "params": (current_task_id,) + params,
         "state": "Queued",
-        "status": "0/0",
+        "status": "Queued",
+        "repeats": "0/0",
         "progress": "0.0%",
         "steps": f"0/{params[5]}",
         "time": "--",
@@ -212,8 +215,8 @@ def update_queue_data():
     for item in queue:
         data.append([
             str(item['id']),
-            item['state'],
             item['status'],
+            item['repeats'],
             item.get('progress', "0.0%"),
             item.get('steps', ''),
             item.get('time', '--'),
@@ -1013,8 +1016,8 @@ def build_callback(state, pipe, num_inference_steps, status):
             'total_steps': num_inference_steps,
             'start_time': start_time,
             'last_update': time.time(),
-            'status': status,
-            'state': phase
+            'repeats': status,
+            'status': phase
         }
     return update_progress
 
@@ -1078,11 +1081,12 @@ def generate_video(
     progress=gr.Progress() #track_tqdm= True
 
 ):
-    global wan_model, offloadobj
+    global wan_model, offloadobj, last_model_type
     reload_needed = state.get("_reload_needed", False)
     file_model_needed = model_needed(image2video)
-    if(server_config.get("reload_model",1) == 2):
-        if file_model_needed != model_filename or reload_needed:
+    with lock:
+        queue_not_empty = len(queue) > 0
+    if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
            del wan_model
            if offloadobj is not None:
                offloadobj.release()
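The hunk above is the core of the "avoid unnecessary reloading" part of the commit: instead of tearing the model down whenever the target file differs, it is only reloaded when the tab type actually changed and either the queue is non-empty or the server is configured to always reload. An illustrative restatement of the new condition, using the same names as the diff but wrapped in a hypothetical helper for readability:

def should_reload(last_model_type, image2video, queue_not_empty,
                  server_config, file_model_needed, model_filename, reload_needed):
    # Hypothetical helper; the diff inlines this expression in generate_video().
    switching_tab = last_model_type != image2video
    always_reload = server_config.get("reload_model", 1) == 2
    model_changed = file_model_needed != model_filename or reload_needed
    return switching_tab and (queue_not_empty or always_reload) and model_changed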
@@ -1251,13 +1255,6 @@ def generate_video(
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
-    with tracker_lock:
-        progress_tracker[task_id] = {
-            'current_step': 0,
-            'total_steps': num_inference_steps,
-            'start_time': time.time(),
-            'last_update': time.time()
-        }
     if trans.enable_teacache:
         trans.teacache_counter = 0
         trans.num_steps = num_inference_steps
@@ -1268,8 +1265,12 @@ def generate_video(
         status = f"{video_no}/{repeat_generation}"
         with tracker_lock:
             if task_id in progress_tracker:
-                progress_tracker[task_id]['state'] = "Encoding Prompt"
-                progress_tracker[task_id]['status'] = status
+                progress_tracker[task_id]['status'] = "Encoding Prompt"
+                progress_tracker[task_id]['repeats'] = status
+                progress_tracker[task_id]['current_step'] = 0
+                progress_tracker[task_id]['total_steps'] = num_inference_steps
+                progress_tracker[task_id]['start_time'] = time.time()
+                progress_tracker[task_id]['last_update'] = time.time()
         callback = build_callback(state, trans, num_inference_steps, status)
         offload.shared_state["callback"] = callback
         gc.collect()
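Since the per-task tracker entry is no longer recreated at the top of generate_video() (see the removed block at -1251 above), it has to be fully reset under `tracker_lock` before each repeat so a previous video's step count and start time do not leak into the next one's progress display. A compact equivalent of the added lines (illustrative; the diff assigns the keys one by one):

with tracker_lock:
    if task_id in progress_tracker:
        progress_tracker[task_id].update({
            'status': "Encoding Prompt",
            'repeats': status,                     # e.g. "1/3"
            'current_step': 0,
            'total_steps': num_inference_steps,
            'start_time': time.time(),
            'last_update': time.time(),
        })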
@@ -1279,7 +1280,7 @@ def generate_video(
         try:
             with tracker_lock:
                 if task_id in progress_tracker:
-                    progress_tracker[task_id]['status'] = video_no
+                    progress_tracker[task_id]['repeats'] = video_no
            video_no += 1
            if image2video:
                samples = wan_model.generate(
@@ -1326,8 +1327,8 @@ def generate_video(
         gen_in_progress = False
         if temp_filename!= None and os.path.isfile(temp_filename):
             os.remove(temp_filename)
-        offload.last_offload_obj.unload_all()
-        offload.unload_loras_from_model(trans)
+        if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
+        if(trans): offload.unload_loras_from_model(trans)
         # if compile:
         #     cache_size = torch._dynamo.config.cache_size_limit
         #     torch.compiler.reset()
@@ -1411,6 +1412,7 @@ def generate_video(
             print(f"New video saved to Path: "+video_path)
             file_list.append(video_path)
         seed += 1
+    last_model_type = image2video
 
     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
@@ -2291,6 +2293,9 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
     use_image2video = new_i2v
 
     if(server_config.get("reload_model",2) == 1):
+        with lock:
+            queue_empty = len(queue) == 0
+        if queue_empty:
            global wan_model, offloadobj
            if wan_model is not None:
                if offloadobj is not None:
@@ -2300,6 +2305,7 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
            gc.collect()
            torch.cuda.empty_cache()
            wan_model, offloadobj, trans = load_models(use_image2video)
+           del trans
 
     t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
     i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
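Finally, with reload_model == 1, switching tabs only swaps models when the queue is empty, so a running generation is never pulled out from under the worker thread; the extra `trans` handle returned by load_models() is dropped immediately because only the wan_model / offloadobj globals are kept. A heavily abridged sketch of the resulting path (names taken from the diff; the release/cleanup lines elided between the two hunks above are approximated by a comment):

if server_config.get("reload_model", 2) == 1:
    with lock:
        queue_empty = len(queue) == 0
    if queue_empty:                              # never reload while tasks are queued
        # ... release the previous model / offload object here (elided in the diff) ...
        gc.collect()
        torch.cuda.empty_cache()
        wan_model, offloadobj, trans = load_models(use_image2video)
        del trans                                # only wan_model/offloadobj are kept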