mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2026-01-12 01:03:44 +00:00
restored original detailed status updates, fixed repeats
This commit is contained in:
parent
bb3f99012f
commit
646d9ace13
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict, Any
|
|
||||||
import argparse
|
import argparse
|
||||||
from mmgp import offload, safetensors2, profile_type
|
from mmgp import offload, safetensors2, profile_type
|
||||||
try:
|
try:
|
||||||
@ -23,7 +22,6 @@ import traceback
|
|||||||
import math
|
import math
|
||||||
import asyncio
|
import asyncio
|
||||||
from wan.utils import prompt_parser
|
from wan.utils import prompt_parser
|
||||||
from typing import List
|
|
||||||
PROMPT_VARS_MAX = 10
|
PROMPT_VARS_MAX = 10
|
||||||
|
|
||||||
target_mmgp_version = "3.3.0"
|
target_mmgp_version = "3.3.0"
|
||||||
@ -49,19 +47,23 @@ def runner():
|
|||||||
with tracker_lock:
|
with tracker_lock:
|
||||||
progress = progress_tracker.get(task_id, {})
|
progress = progress_tracker.get(task_id, {})
|
||||||
|
|
||||||
if item['status'] == "Processing":
|
if item['state'] != "Queued" and item['state'] != "Finished":
|
||||||
current_step = progress.get('current_step', 0)
|
current_step = progress.get('current_step', 0)
|
||||||
total_steps = progress.get('total_steps', 0)
|
total_steps = progress.get('total_steps', 0)
|
||||||
elapsed = time.time() - progress.get('start_time', time.time())
|
elapsed = time.time() - progress.get('start_time', time.time())
|
||||||
|
status = progress.get('status', "")
|
||||||
|
state = progress.get("state")
|
||||||
item.update({
|
item.update({
|
||||||
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
||||||
'steps': f"{current_step}/{total_steps}",
|
'steps': f"{current_step}/{total_steps}",
|
||||||
'time': f"{elapsed:.1f}s"
|
'time': f"{elapsed:.1f}s",
|
||||||
|
'state': f"{state}",
|
||||||
|
'status': f"{status}"
|
||||||
})
|
})
|
||||||
if not any(item['status'] == "Processing" for item in queue):
|
if not any(item['state'] == "Processing" for item in queue):
|
||||||
for item in queue:
|
for item in queue:
|
||||||
if item['status'] == "Queued":
|
if item['state'] == "Queued":
|
||||||
item['status'] = "Processing"
|
item['state'] = "Processing"
|
||||||
current_task_id = item['id']
|
current_task_id = item['id']
|
||||||
threading.Thread(target=process_task, args=(item,)).start()
|
threading.Thread(target=process_task, args=(item,)).start()
|
||||||
break
|
break
|
||||||
@ -157,7 +159,8 @@ def add_video_task(*params):
|
|||||||
queue.append({
|
queue.append({
|
||||||
"id": current_task_id,
|
"id": current_task_id,
|
||||||
"params": (current_task_id,) + params,
|
"params": (current_task_id,) + params,
|
||||||
"status": "Queued",
|
"state": "Queued",
|
||||||
|
"status": "0/0",
|
||||||
"progress": "0.0%",
|
"progress": "0.0%",
|
||||||
"steps": f"0/{params[5]}",
|
"steps": f"0/{params[5]}",
|
||||||
"time": "--",
|
"time": "--",
|
||||||
@ -209,6 +212,7 @@ def update_queue_data():
|
|||||||
for item in queue:
|
for item in queue:
|
||||||
data.append([
|
data.append([
|
||||||
str(item['id']),
|
str(item['id']),
|
||||||
|
item['state'],
|
||||||
item['status'],
|
item['status'],
|
||||||
item.get('progress', "0.0%"),
|
item.get('progress', "0.0%"),
|
||||||
item.get('steps', ''),
|
item.get('steps', ''),
|
||||||
@ -985,29 +989,38 @@ def save_video(final_frames, output_path, fps=24):
|
|||||||
final_frames = (final_frames * 255).astype(np.uint8)
|
final_frames = (final_frames * 255).astype(np.uint8)
|
||||||
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
||||||
|
|
||||||
def build_callback(task_id, total_steps):
|
def build_callback(state, pipe, num_inference_steps, status):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
def update_progress(step, _):
|
def update_progress(step_idx, latents, read_state = False):
|
||||||
with tracker_lock:
|
with tracker_lock:
|
||||||
|
step_idx += 1
|
||||||
|
if state.get("abort", False):
|
||||||
|
# pipe._interrupt = True
|
||||||
|
phase = "Aborting"
|
||||||
|
elif step_idx == num_inference_steps:
|
||||||
|
phase = "VAE Decoding"
|
||||||
|
else:
|
||||||
|
phase = "Denoising"
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
progress_tracker[task_id] = {
|
progress_tracker[task_id] = {
|
||||||
'current_step': step + 1,
|
'current_step': step_idx,
|
||||||
'total_steps': total_steps,
|
'total_steps': num_inference_steps,
|
||||||
'start_time': start_time,
|
'start_time': start_time,
|
||||||
'last_update': time.time()
|
'last_update': time.time(),
|
||||||
|
'status': status,
|
||||||
|
'state': phase
|
||||||
}
|
}
|
||||||
return update_progress
|
return update_progress
|
||||||
|
|
||||||
def refresh_gallery(state):
|
def refresh_gallery(state):
|
||||||
file_list = state.get("file_list", None)
|
file_list = state.get("file_list", None)
|
||||||
return file_list
|
return file_list
|
||||||
|
|
||||||
def finalize_gallery(state):
|
def finalize_gallery(state):
|
||||||
choice = 0
|
choice = 0
|
||||||
if "in_progress" in state:
|
if "in_progress" in state:
|
||||||
del state["in_progress"]
|
del state["in_progress"]
|
||||||
choice = state.get("selected",0)
|
choice = state.get("selected",0)
|
||||||
|
|
||||||
state["extra_orders"] = 0
|
|
||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
global gen_in_progress
|
global gen_in_progress
|
||||||
gen_in_progress = False
|
gen_in_progress = False
|
||||||
@ -1067,6 +1080,7 @@ def generate_video(
|
|||||||
offloadobj.release()
|
offloadobj.release()
|
||||||
offloadobj = None
|
offloadobj = None
|
||||||
wan_model = None
|
wan_model = None
|
||||||
|
gc.collect()
|
||||||
print(f"Loading model {get_model_name(file_model_needed)}...")
|
print(f"Loading model {get_model_name(file_model_needed)}...")
|
||||||
wan_model, offloadobj, trans = load_models(image2video)
|
wan_model, offloadobj, trans = load_models(image2video)
|
||||||
print(f"Model loaded")
|
print(f"Model loaded")
|
||||||
@ -1243,13 +1257,23 @@ def generate_video(
|
|||||||
trans.teacache_skipped_steps = 0
|
trans.teacache_skipped_steps = 0
|
||||||
trans.previous_residual_uncond = None
|
trans.previous_residual_uncond = None
|
||||||
trans.previous_residual_cond = None
|
trans.previous_residual_cond = None
|
||||||
callback = build_callback(task_id, num_inference_steps)
|
video_no = 0
|
||||||
|
status = f"{video_no}/{repeat_generation}"
|
||||||
|
with tracker_lock:
|
||||||
|
if task_id in progress_tracker:
|
||||||
|
progress_tracker[task_id]['state'] = "Encoding Prompt"
|
||||||
|
progress_tracker[task_id]['status'] = status
|
||||||
|
callback = build_callback(state, trans, num_inference_steps, status)
|
||||||
offload.shared_state["callback"] = callback
|
offload.shared_state["callback"] = callback
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
wan_model._interrupt = False
|
wan_model._interrupt = False
|
||||||
state["progress_status"] = "Starting"
|
for i in range(repeat_generation):
|
||||||
try:
|
try:
|
||||||
|
with tracker_lock:
|
||||||
|
if task_id in progress_tracker:
|
||||||
|
progress_tracker[task_id]['status'] = video_no
|
||||||
|
video_no += 1
|
||||||
if image2video:
|
if image2video:
|
||||||
samples = wan_model.generate(
|
samples = wan_model.generate(
|
||||||
prompt,
|
prompt,
|
||||||
@ -2006,10 +2030,10 @@ def generate_video_tab(image2video=False):
|
|||||||
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
||||||
generate_btn = gr.Button("Generate")
|
generate_btn = gr.Button("Generate")
|
||||||
queue_df = gr.DataFrame(
|
queue_df = gr.DataFrame(
|
||||||
headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
|
headers=["ID", "Status", "Repeats", "Progress", "Steps", "Time", "Prompt", "", "", ""],
|
||||||
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
|
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str", "str"],
|
||||||
interactive=False,
|
interactive=False,
|
||||||
col_count=(9, "fixed"),
|
col_count=(10, "fixed"),
|
||||||
wrap=True,
|
wrap=True,
|
||||||
value=update_queue_data,
|
value=update_queue_data,
|
||||||
every=1,
|
every=1,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user