From 8f1e6149a834cff2f25d4495af52738218de5b92 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 3 Apr 2025 01:31:32 +0200
Subject: [PATCH] fixed bugs

---
 gradio_server.py  | 21 +++++++++++----------
 rife/inference.py |  4 ++--
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/gradio_server.py b/gradio_server.py
index d883726..e85f67b 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -1773,7 +1773,6 @@ def generate_video(
         # yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
     else:
         sample = samples.cpu()
-        # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
 
         time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
         if os.name == 'nt':
@@ -1782,14 +1781,14 @@ def generate_video(
             file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
         video_path = os.path.join(save_path, file_name)
         # if False: # for testing
-        #     torch.save(sample, "ouput.pt")
+        #     torch.save(sample, "output.pt")
         # else:
-        #     sample =torch.load("ouput.pt")
+        #     sample =torch.load("output.pt")
         exp = 0
         fps = 16
 
         if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0:
-            progress_args = [0, status + " - Upsampling"]
+            progress_args = [(num_inference_steps , num_inference_steps) , status + " - Upsampling" , num_inference_steps]
             progress(*progress_args )
             gen["progress_args"] = progress_args
 
@@ -1804,7 +1803,7 @@ def generate_video(
             fps = fps * 2**exp
 
         if len(spatial_upsampling) > 0:
-            from wan.utils.utils import resize_lanczos
+            from wan.utils.utils import resize_lanczos # need multithreading or to do lanczos with cuda
             if spatial_upsampling == "lanczos1.5":
                 scale = 1.5
             else:
@@ -2712,10 +2711,12 @@ def generate_video_tab(image2video=False):
                 queue_df = gr.DataFrame(
                     headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
                     datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
+                    column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
                     interactive=False,
                     col_count=(9, "fixed"),
                     wrap=True,
                     value=[],
+                    line_breaks= True,
                     visible= False,
                     # every=1,
                     elem_id="queue_df"
@@ -3266,15 +3267,16 @@ def create_demo():
                 pointer-events: none;
                 text-align: center;
                 vertical-align: middle;
+                font-size:11px;
             }
-            #queue_df table {
+            #xqueue_df table {
                 width: 100%;
                 overflow: hidden !important;
             }
-            #queue_df::-webkit-scrollbar {
+            #xqueue_df::-webkit-scrollbar {
                 display: none !important;
             }
-            #queue_df {
+            #xqueue_df {
                 scrollbar-width: none !important;
                 -ms-overflow-style: none !important;
             }
@@ -3292,9 +3294,8 @@ def create_demo():
                 cursor: default !important;
                 pointer-events: none;
             }
-            #queue_df th:nth-child(2),
+            #xqueue_df th:nth-child(2),
             #queue_df td:nth-child(2) {
-                width: auto;
                 text-align: center;
                 vertical-align: middle;
                 white-space: normal;
diff --git a/rife/inference.py b/rife/inference.py
index 24a2bdd..8f775d3 100644
--- a/rife/inference.py
+++ b/rife/inference.py
@@ -73,7 +73,7 @@ def process_frames(model, device, frames, exp):
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
 
         break_flag = False
-        if ssim > 0.996:
+        if ssim > 0.996 or pos > 100:
             pos += 1
             frame = get_frame(frames, pos)
             if frame is None:
@@ -86,7 +86,7 @@
             I1 = model.inference(I0, I1, scale)
             I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
             ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            frame = I1[0]
+            frame = I1[0][:, :h, :w]
 
             if ssim < 0.2:
                 output = []
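
The rife/inference.py hunks touch the SSIM-based duplicate-frame test in
process_frames(). Below is a minimal, hedged sketch of that idea, assuming
frames are (1, C, H, W) float tensors in [0, 1]; ssim_like is an illustrative
stand-in for the ssim_matlab() helper bundled with RIFE, and only the 32x32
downscale and the 0.996 threshold are taken from the hunk itself.

    import torch
    import torch.nn.functional as F

    def ssim_like(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # crude similarity proxy for illustration only (the repo uses ssim_matlab)
        return 1.0 - (a - b).abs().mean()

    def looks_like_duplicate(frame0: torch.Tensor, frame1: torch.Tensor) -> bool:
        # compare tiny 32x32 thumbnails, as the patched loop does, and treat
        # near-identical frames (similarity > 0.996) as duplicates
        a = F.interpolate(frame0, (32, 32), mode="bilinear", align_corners=False)
        b = F.interpolate(frame1, (32, 32), mode="bilinear", align_corners=False)
        return ssim_like(a[:, :3], b[:, :3]).item() > 0.996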
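
The new inline comment in generate_video() ("need multithreading or to do
lanczos with cuda") points at CPU-side Lanczos resizing being a bottleneck. A
hedged sketch of the multithreaded variant it alludes to, assuming frames are
uint8 HxWx3 numpy arrays; the resize helper here is a hypothetical PIL-based
stand-in, not the wan.utils.utils.resize_lanczos implementation:

    from concurrent.futures import ThreadPoolExecutor
    import numpy as np
    from PIL import Image

    def lanczos_resize(frame: np.ndarray, w: int, h: int) -> np.ndarray:
        # hypothetical stand-in: Lanczos-resample one uint8 HxWx3 frame
        return np.asarray(Image.fromarray(frame).resize((w, h), Image.LANCZOS))

    def lanczos_resize_video(frames: list, w: int, h: int) -> list:
        # Pillow's C resampling largely runs outside the GIL, so a thread pool
        # can use several cores without copying frames between processes
        with ThreadPoolExecutor() as pool:
            return list(pool.map(lambda f: lanczos_resize(f, w, h), frames))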