fixed bugs

This commit is contained in:
DeepBeepMeep 2025-04-03 01:31:32 +02:00
parent a8fea50e78
commit 8f1e6149a8
2 changed files with 13 additions and 12 deletions

View File

@ -1773,7 +1773,6 @@ def generate_video(
# yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s" # yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else: else:
sample = samples.cpu() sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt': if os.name == 'nt':
@ -1782,14 +1781,14 @@ def generate_video(
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
video_path = os.path.join(save_path, file_name) video_path = os.path.join(save_path, file_name)
# if False: # for testing # if False: # for testing
# torch.save(sample, "ouput.pt") # torch.save(sample, "output.pt")
# else: # else:
# sample =torch.load("ouput.pt") # sample =torch.load("output.pt")
exp = 0 exp = 0
fps = 16 fps = 16
if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0: if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0:
progress_args = [0, status + " - Upsampling"] progress_args = [(num_inference_steps , num_inference_steps) , status + " - Upsampling" , num_inference_steps]
progress(*progress_args ) progress(*progress_args )
gen["progress_args"] = progress_args gen["progress_args"] = progress_args
@ -1804,7 +1803,7 @@ def generate_video(
fps = fps * 2**exp fps = fps * 2**exp
if len(spatial_upsampling) > 0: if len(spatial_upsampling) > 0:
from wan.utils.utils import resize_lanczos from wan.utils.utils import resize_lanczos # need multithreading or to do lanczos with cuda
if spatial_upsampling == "lanczos1.5": if spatial_upsampling == "lanczos1.5":
scale = 1.5 scale = 1.5
else: else:
@ -2712,10 +2711,12 @@ def generate_video_tab(image2video=False):
queue_df = gr.DataFrame( queue_df = gr.DataFrame(
headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""], headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"], datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
interactive=False, interactive=False,
col_count=(9, "fixed"), col_count=(9, "fixed"),
wrap=True, wrap=True,
value=[], value=[],
line_breaks= True,
visible= False, visible= False,
# every=1, # every=1,
elem_id="queue_df" elem_id="queue_df"
@ -3266,15 +3267,16 @@ def create_demo():
pointer-events: none; pointer-events: none;
text-align: center; text-align: center;
vertical-align: middle; vertical-align: middle;
font-size:11px;
} }
#queue_df table { #xqueue_df table {
width: 100%; width: 100%;
overflow: hidden !important; overflow: hidden !important;
} }
#queue_df::-webkit-scrollbar { #xqueue_df::-webkit-scrollbar {
display: none !important; display: none !important;
} }
#queue_df { #xqueue_df {
scrollbar-width: none !important; scrollbar-width: none !important;
-ms-overflow-style: none !important; -ms-overflow-style: none !important;
} }
@ -3292,9 +3294,8 @@ def create_demo():
cursor: default !important; cursor: default !important;
pointer-events: none; pointer-events: none;
} }
#queue_df th:nth-child(2), #xqueue_df th:nth-child(2),
#queue_df td:nth-child(2) { #queue_df td:nth-child(2) {
width: auto;
text-align: center; text-align: center;
vertical-align: middle; vertical-align: middle;
white-space: normal; white-space: normal;

View File

@ -73,7 +73,7 @@ def process_frames(model, device, frames, exp):
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
break_flag = False break_flag = False
if ssim > 0.996: if ssim > 0.996 or pos > 100:
pos += 1 pos += 1
frame = get_frame(frames, pos) frame = get_frame(frames, pos)
if frame is None: if frame is None:
@ -86,7 +86,7 @@ def process_frames(model, device, frames, exp):
I1 = model.inference(I0, I1, scale) I1 = model.inference(I0, I1, scale)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = I1[0] frame = I1[0][:, :h, :w]
if ssim < 0.2: if ssim < 0.2:
output = [] output = []