diff --git a/gradio_server.py b/gradio_server.py
index b3ccd94..0a0fafb 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -37,6 +37,7 @@ task_id = 0
 progress_tracker = {}
 tracker_lock = threading.Lock()
 file_list = []
+last_model_type = None
 
 def runner():
     global current_task_id
@@ -47,22 +48,23 @@ def runner():
             with tracker_lock:
                 progress = progress_tracker.get(task_id, {})
-            if item['state'] != "Queued" and item['state'] != "Finished":
+            if item['state'] == "Processing":
                 current_step = progress.get('current_step', 0)
                 total_steps = progress.get('total_steps', 0)
                 elapsed = time.time() - progress.get('start_time', time.time())
                 status = progress.get('status', "")
-                state = progress.get("state")
+                repeats = progress.get("repeats")
                 item.update({
                     'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
                     'steps': f"{current_step}/{total_steps}",
                     'time': f"{elapsed:.1f}s",
-                    'state': f"{state}",
+                    'repeats': f"{repeats}",
                     'status': f"{status}"
                 })
 
         if not any(item['state'] == "Processing" for item in queue):
             for item in queue:
                 if item['state'] == "Queued":
+                    item['status'] = "Processing"
                     item['state'] = "Processing"
                     current_task_id = item['id']
                     threading.Thread(target=process_task, args=(item,)).start()
@@ -160,7 +162,8 @@ def add_video_task(*params):
         "id": current_task_id,
         "params": (current_task_id,) + params,
         "state": "Queued",
-        "status": "0/0",
+        "status": "Queued",
+        "repeats": "0/0",
         "progress": "0.0%",
         "steps": f"0/{params[5]}",
         "time": "--",
@@ -212,8 +215,8 @@ def update_queue_data():
     for item in queue:
         data.append([
             str(item['id']),
-            item['state'],
             item['status'],
+            item['repeats'],
             item.get('progress', "0.0%"),
             item.get('steps', ''),
             item.get('time', '--'),
@@ -1013,8 +1016,8 @@ def build_callback(state, pipe, num_inference_steps, status):
                 'total_steps': num_inference_steps,
                 'start_time': start_time,
                 'last_update': time.time(),
-                'status': status,
-                'state': phase
+                'repeats': status,
+                'status': phase
             }
 
     return update_progress
@@ -1078,20 +1081,21 @@ def generate_video(
     progress=gr.Progress() #track_tqdm= True
 ):
-    global wan_model, offloadobj
+    global wan_model, offloadobj, last_model_type
     reload_needed = state.get("_reload_needed", False)
     file_model_needed = model_needed(image2video)
-    if(server_config.get("reload_model",1) == 2):
-        if file_model_needed != model_filename or reload_needed:
-            del wan_model
-            if offloadobj is not None:
-                offloadobj.release()
-            del offloadobj
-            gc.collect()
-            print(f"Loading model {get_model_name(file_model_needed)}...")
-            wan_model, offloadobj, trans = load_models(image2video)
-            print(f"Model loaded")
-            state["_reload_needed"] = False
+    with lock:
+        queue_not_empty = len(queue) > 0
+        if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
+            del wan_model
+            if offloadobj is not None:
+                offloadobj.release()
+            del offloadobj
+            gc.collect()
+            print(f"Loading model {get_model_name(file_model_needed)}...")
+            wan_model, offloadobj, trans = load_models(image2video)
+            print(f"Model loaded")
+            state["_reload_needed"] = False
 
     from PIL import Image
     import numpy as np
 
@@ -1251,13 +1255,6 @@ def generate_video(
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
-    with tracker_lock:
-        progress_tracker[task_id] = {
-            'current_step': 0,
-            'total_steps': num_inference_steps,
-            'start_time': time.time(),
-            'last_update': time.time()
-        }
     if trans.enable_teacache:
         trans.teacache_counter = 0
         trans.num_steps = num_inference_steps
@@ -1268,8 +1265,12 @@
         status = f"{video_no}/{repeat_generation}"
         with tracker_lock:
             if task_id in progress_tracker:
-                progress_tracker[task_id]['state'] = "Encoding Prompt"
-                progress_tracker[task_id]['status'] = status
+                progress_tracker[task_id]['status'] = "Encoding Prompt"
+                progress_tracker[task_id]['repeats'] = status
+                progress_tracker[task_id]['current_step'] = 0
+                progress_tracker[task_id]['total_steps'] = num_inference_steps
+                progress_tracker[task_id]['start_time'] = time.time()
+                progress_tracker[task_id]['last_update'] = time.time()
         callback = build_callback(state, trans, num_inference_steps, status)
         offload.shared_state["callback"] = callback
         gc.collect()
@@ -1279,7 +1280,7 @@
         try:
             with tracker_lock:
                 if task_id in progress_tracker:
-                    progress_tracker[task_id]['status'] = video_no
+                    progress_tracker[task_id]['repeats'] = video_no
             video_no += 1
             if image2video:
                 samples = wan_model.generate(
@@ -1326,8 +1327,8 @@
             gen_in_progress = False
             if temp_filename!= None and os.path.isfile(temp_filename):
                 os.remove(temp_filename)
-            offload.last_offload_obj.unload_all()
-            offload.unload_loras_from_model(trans)
+            if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
+            if(trans): offload.unload_loras_from_model(trans)
             # if compile:
             #     cache_size = torch._dynamo.config.cache_size_limit
             #     torch.compiler.reset()
@@ -1411,6 +1412,7 @@
             print(f"New video saved to Path: "+video_path)
             file_list.append(video_path)
             seed += 1
+            last_model_type = image2video
 
     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
@@ -2291,15 +2293,19 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
 
     use_image2video = new_i2v
     if(server_config.get("reload_model",2) == 1):
-        global wan_model, offloadobj
-        if wan_model is not None:
-            if offloadobj is not None:
-                offloadobj.release()
-                offloadobj = None
-            wan_model = None
-            gc.collect()
-            torch.cuda.empty_cache()
-        wan_model, offloadobj, trans = load_models(use_image2video)
+        with lock:
+            queue_empty = len(queue) == 0
+            if queue_empty:
+                global wan_model, offloadobj
+                if wan_model is not None:
+                    if offloadobj is not None:
+                        offloadobj.release()
+                        offloadobj = None
+                    wan_model = None
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                wan_model, offloadobj, trans = load_models(use_image2video)
+                del trans
 
     t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
     i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
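Note on the tracker changes: the old 'state' key is dropped, 'status' now carries the phase text (e.g. "Encoding Prompt"), the new 'repeats' key carries the per-task video counter, and the tracker entry is re-armed at the start of every repeat rather than once per task, so the step counter and elapsed timer restart for each video. A minimal standalone sketch of that reset, with key names taken from the diff and the surrounding generation loop assumed (`reset_tracker` is a hypothetical helper for illustration):

```python
import threading
import time

progress_tracker = {}            # task_id -> progress fields, polled by runner()
tracker_lock = threading.Lock()

def reset_tracker(task_id, video_no, repeat_generation, num_inference_steps):
    # Called at the top of each repeat: restart the step counter and timer,
    # advance the 'repeats' display (e.g. "2/4"), and reset the phase text.
    with tracker_lock:
        if task_id in progress_tracker:
            progress_tracker[task_id].update({
                'status': "Encoding Prompt",                   # current phase
                'repeats': f"{video_no}/{repeat_generation}",  # videos done/total
                'current_step': 0,
                'total_steps': num_inference_steps,
                'start_time': time.time(),
                'last_update': time.time(),
            })
```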
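Note on the reload logic: generate_video and on_tab_select now both take `lock` before inspecting `queue`, so a tab switch can no longer swap models out from under queued tasks, and generate_video only reloads when the requested task type actually changed. A simplified sketch of that guard under stated assumptions: `lock` protects `queue` as in the patch, `load_models` is a stub for the real loader (which returns the (wan_model, offloadobj, trans) triple), and `ensure_model`, `checkpoint_changed`, and `always_reload` are hypothetical names for this sketch only.

```python
import gc
import threading

lock = threading.Lock()   # protects queue and model swaps (assumed, as in the patch)
queue = []                # pending generation tasks
wan_model = None
last_model_type = None    # None until the first video finishes

def load_models(image2video):
    return object()       # stub standing in for the real checkpoint loader

def ensure_model(image2video, checkpoint_changed, always_reload):
    # Reload only if the requested task type differs from the last completed
    # run AND (work is queued or reload_model == 2) AND the checkpoint changed.
    global wan_model
    with lock:
        queue_not_empty = len(queue) > 0
        if (last_model_type != image2video
                and (queue_not_empty or always_reload)
                and checkpoint_changed):
            wan_model = None   # drop old weights before loading new ones
            gc.collect()
            wan_model = load_models(image2video)

# In the patch, last_model_type is only set to image2video after a video has
# been saved, so a failed run does not suppress the next reload attempt.
```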