mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-05 06:29:14 +00:00

commit 646d9ace13 (parent bb3f99012f)
restored original detailed status updates, fixed repeats

gradio_server.py (308)
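For orientation, a minimal sketch (not the actual file) of the pattern this diff restores: a per-task dictionary guarded by a lock, a per-step callback that records both a coarse phase ('state') and a repeat counter ('status'), and a poller that formats those fields for the queue table. The names progress_tracker and tracker_lock come from the hunks below; make_callback, poll and the literal values are simplified assumptions for illustration only.

import threading, time

progress_tracker = {}             # task_id -> latest progress snapshot
tracker_lock = threading.Lock()   # guards all tracker access across threads

def make_callback(task_id, total_steps, status):
    # 'status' carries the repeat counter (e.g. "1/4"); 'state' is the coarse phase.
    start = time.time()
    def on_step(step_idx, _latents):
        with tracker_lock:
            progress_tracker[task_id] = {
                'current_step': step_idx + 1,
                'total_steps': total_steps,
                'start_time': start,
                'status': status,
                'state': "Denoising",
            }
    return on_step

def poll(task_id):
    # Reader side (what runner() does every tick): format the snapshot for the queue table.
    with tracker_lock:
        p = progress_tracker.get(task_id, {})
    done, total = p.get('current_step', 0), p.get('total_steps', 0)
    return f"{(done / total * 100) if total else 0:.1f}%"

cb = make_callback(task_id=1, total_steps=30, status="1/1")
cb(10, None)
print(poll(1))   # -> "36.7%"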
@@ -1,7 +1,6 @@
import os
import time
import threading
from typing import Dict, Any
import argparse
from mmgp import offload, safetensors2, profile_type
try:
@@ -23,7 +22,6 @@ import traceback
import math
import asyncio
from wan.utils import prompt_parser
from typing import List
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.3.0"
@@ -49,19 +47,23 @@ def runner():
with tracker_lock:
progress = progress_tracker.get(task_id, {})

if item['status'] == "Processing":
if item['state'] != "Queued" and item['state'] != "Finished":
current_step = progress.get('current_step', 0)
total_steps = progress.get('total_steps', 0)
elapsed = time.time() - progress.get('start_time', time.time())
status = progress.get('status', "")
state = progress.get("state")
item.update({
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
'steps': f"{current_step}/{total_steps}",
'time': f"{elapsed:.1f}s"
'time': f"{elapsed:.1f}s",
'state': f"{state}",
'status': f"{status}"
})
if not any(item['status'] == "Processing" for item in queue):
if not any(item['state'] == "Processing" for item in queue):
for item in queue:
if item['status'] == "Queued":
item['status'] = "Processing"
if item['state'] == "Queued":
item['state'] = "Processing"
current_task_id = item['id']
threading.Thread(target=process_task, args=(item,)).start()
break
@@ -157,7 +159,8 @@ def add_video_task(*params):
queue.append({
"id": current_task_id,
"params": (current_task_id,) + params,
"status": "Queued",
"state": "Queued",
"status": "0/0",
"progress": "0.0%",
"steps": f"0/{params[5]}",
"time": "--",
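The hunk above is the "fixed repeats" part of the commit message: the old dict literal assigned the "status" key twice ("Queued" and then "0/0"), and in Python the later value silently wins, so the queue lost its lifecycle flag. The new entry keeps the two meanings apart. A minimal illustration of the resulting shape (field values are invented examples, not taken from a real run):

# Duplicate keys in a dict literal do not raise; the last one wins:
item = {"status": "Queued", "status": "0/0"}
assert item == {"status": "0/0"}

# Shape of a queue entry after the fix (example values):
item = {
    "id": 1,
    "state": "Queued",     # lifecycle: Queued / Processing / Finished
    "status": "0/0",       # repeat counter: videos done / videos requested
    "progress": "0.0%",
    "steps": "0/30",
    "time": "--",
}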
@@ -209,6 +212,7 @@ def update_queue_data():
for item in queue:
data.append([
str(item['id']),
item['state'],
item['status'],
item.get('progress', "0.0%"),
item.get('steps', ''),
@@ -985,29 +989,38 @@ def save_video(final_frames, output_path, fps=24):
final_frames = (final_frames * 255).astype(np.uint8)
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)

def build_callback(task_id, total_steps):
def build_callback(state, pipe, num_inference_steps, status):
start_time = time.time()
def update_progress(step, _):
def update_progress(step_idx, latents, read_state = False):
with tracker_lock:
step_idx += 1
if state.get("abort", False):
# pipe._interrupt = True
phase = "Aborting"
elif step_idx == num_inference_steps:
phase = "VAE Decoding"
else:
phase = "Denoising"
elapsed = time.time() - start_time
progress_tracker[task_id] = {
'current_step': step + 1,
'total_steps': total_steps,
'current_step': step_idx,
'total_steps': num_inference_steps,
'start_time': start_time,
'last_update': time.time()
'last_update': time.time(),
'status': status,
'state': phase
}
return update_progress

def refresh_gallery(state):
file_list = state.get("file_list", None)
return file_list

def finalize_gallery(state):
choice = 0
if "in_progress" in state:
del state["in_progress"]
choice = state.get("selected",0)

state["extra_orders"] = 0
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
@@ -1067,6 +1080,7 @@ def generate_video(
offloadobj.release()
offloadobj = None
wan_model = None
gc.collect()
print(f"Loading model {get_model_name(file_model_needed)}...")
wan_model, offloadobj, trans = load_models(image2video)
print(f"Model loaded")
@@ -1243,143 +1257,153 @@ def generate_video(
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
callback = build_callback(task_id, num_inference_steps)
video_no = 0
status = f"{video_no}/{repeat_generation}"
with tracker_lock:
if task_id in progress_tracker:
progress_tracker[task_id]['state'] = "Encoding Prompt"
progress_tracker[task_id]['status'] = status
callback = build_callback(state, trans, num_inference_steps, status)
offload.shared_state["callback"] = callback
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
state["progress_status"] = "Starting"
try:
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue.convert('RGB'),
image_to_end.convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
for i in range(repeat_generation):
try:
with tracker_lock:
if task_id in progress_tracker:
progress_tracker[task_id]['status'] = video_no
video_no += 1
if image2video:
samples = wan_model.generate(
prompt,
image_to_continue.convert('RGB'),
image_to_end.convert('RGB') if image_to_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution],
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)

else:
samples = wan_model.generate(
prompt,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
else:
samples = wan_model.generate(
prompt,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
offload.last_offload_obj.unload_all()
offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size

gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = ["vram", "VRAM", "memory","allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
VRAM_crash = True
else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
state["prompt"] = ""
|
||||
if VRAM_crash:
|
||||
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
||||
else:
|
||||
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
||||
finally:
|
||||
with tracker_lock:
|
||||
if task_id in progress_tracker:
|
||||
del progress_tracker[task_id]
|
||||
|
||||
if trans.enable_teacache:
|
||||
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
||||
trans.previous_residual_uncond = None
|
||||
trans.previous_residual_cond = None
|
||||
|
||||
if samples != None:
|
||||
samples = samples.to("cpu")
|
||||
offload.last_offload_obj.unload_all()
|
||||
offload.unload_loras_from_model(trans)
|
||||
# if compile:
|
||||
# cache_size = torch._dynamo.config.cache_size_limit
|
||||
# torch.compiler.reset()
|
||||
# torch._dynamo.config.cache_size_limit = cache_size
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
s = str(e)
|
||||
keyword_list = ["vram", "VRAM", "memory","allocat"]
|
||||
VRAM_crash= False
|
||||
if any( keyword in s for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
|
||||
if samples == None:
|
||||
end_time = time.time()
|
||||
abort = True
|
||||
state["prompt"] = ""
|
||||
else:
|
||||
stack = traceback.extract_stack(f=None, limit=5)
|
||||
for frame in stack:
|
||||
if any( keyword in frame.name for keyword in keyword_list):
|
||||
VRAM_crash = True
|
||||
break
|
||||
state["prompt"] = ""
|
||||
if VRAM_crash:
|
||||
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
||||
else:
|
||||
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
||||
finally:
|
||||
with tracker_lock:
|
||||
if task_id in progress_tracker:
|
||||
del progress_tracker[task_id]
|
||||
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")

if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name)
cache_video(
tensor=sample[None],
save_file=video_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))

if samples != None:
samples = samples.to("cpu")
offload.last_offload_obj.unload_all()
gc.collect()
torch.cuda.empty_cache()
configs = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'resolution': resolution,
'video_length': video_length,
'seed': seed,
'num_inference_steps': num_inference_steps,
}

if samples == None:
end_time = time.time()
abort = True
state["prompt"] = ""
else:
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
from mutagen.mp4 import MP4
file = MP4(video_path)
file.tags['©cmt'] = [json.dumps(configs)]
file.save()
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name)
cache_video(
tensor=sample[None],
save_file=video_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))

configs = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'resolution': resolution,
'video_length': video_length,
'seed': seed,
'num_inference_steps': num_inference_steps,
}

metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
from mutagen.mp4 import MP4
file = MP4(video_path)
file.tags['©cmt'] = [json.dumps(configs)]
file.save()

print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
seed += 1
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
seed += 1

if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
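The hunk above interleaves the removed single-shot path with the restored per-repeat loop, which makes it hard to read as flattened text. Condensed, the new control flow per repetition looks roughly like the sketch below. This is a simplified reading, not the actual code: fake_generate and the literal values are placeholders, error handling and the generate arguments are trimmed, and only the names progress_tracker, tracker_lock, repeat_generation, video_no, seed, state and file_list come from the hunk.

import threading

# Stand-ins so the sketch runs on its own; the real objects live in gradio_server.py.
progress_tracker, tracker_lock = {}, threading.Lock()
state, file_list = {}, []
task_id, video_no, seed, repeat_generation = 1, 0, 42, 2

def fake_generate(seed):
    # Placeholder for wan_model.generate(...); would return None on an aborted run.
    return f"<video tensor, seed={seed}>"

for i in range(repeat_generation):
    try:
        with tracker_lock:                       # publish which repeat is running
            progress_tracker[task_id] = {'status': video_no}
        video_no += 1
        samples = fake_generate(seed)
    finally:
        with tracker_lock:                       # tracker entry is dropped when the attempt ends
            progress_tracker.pop(task_id, None)

    if samples is None:                          # aborted repeat: reset the prompt, save nothing
        state["prompt"] = ""
    else:
        file_list.append(f"video_seed{seed}.mp4")
        seed += 1                                # each saved repeat advances the seed

print(file_list)                                 # -> ['video_seed42.mp4', 'video_seed43.mp4']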
@@ -2006,10 +2030,10 @@ def generate_video_tab(image2video=False):
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
generate_btn = gr.Button("Generate")
queue_df = gr.DataFrame(
headers=["ID", "Status", "Progress", "Steps", "Time", "Prompt", "", "", ""],
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str"],
headers=["ID", "Status", "Repeats", "Progress", "Steps", "Time", "Prompt", "", "", ""],
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str", "str"],
interactive=False,
col_count=(9, "fixed"),
col_count=(10, "fixed"),
wrap=True,
value=update_queue_data,
every=1,
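The last hunk widens the queue table from nine to ten columns so the lifecycle column ("Status") and the repeat counter ("Repeats") can be shown side by side, matching the extra item['status'] field appended in update_queue_data above. A small self-contained check of the new row shape; the headers are copied from the hunk, while the sample item and the trailing prompt/blank columns are illustrative assumptions:

headers = ["ID", "Status", "Repeats", "Progress", "Steps", "Time", "Prompt", "", "", ""]

# One queue row as update_queue_data() would now emit it (example values):
item = {"id": 3, "state": "Processing", "status": "1/4",
        "progress": "36.7%", "steps": "11/30", "time": "41.2s", "prompt": "a red panda ..."}
row = [str(item["id"]), item["state"], item["status"],
       item.get("progress", "0.0%"), item.get("steps", ""), item.get("time", "--"),
       item["prompt"], "", "", ""]

assert len(row) == len(headers) == 10   # must agree with col_count=(10, "fixed")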