Added Vace Sliding Window

This commit is contained in:
DeepBeepMeep 2025-04-13 01:36:57 +02:00
parent 4c79c62419
commit 033546ca42
6 changed files with 304 additions and 168 deletions

View File

@ -14,10 +14,14 @@
## 🔥 Latest News!!
* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
- A new queuing system that lets you queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different numbers of frames, steps, loras, ...).
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
- Wan Vace Control Net support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
- Integrated the *Matanyone* tool directly inside WanGP so that you can easily create inpainting masks
- Sliding Window generation for Vace: create videos that can last dozens of seconds
- A new UI, tabs were replaced by a Dropdown box to easily switch models
* Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2video 1.3B model: Image 2 Video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. You can choose any of these models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
* Mar 26 2025: 👋 Good news! Official support for RTX 50xx; please check the installation instructions below.
* Mar 24 2025: 👋 Wan2.1GP v3.2:

View File

@ -163,10 +163,10 @@ def get_frames_from_video(video_input, video_state):
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], \
gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=True), \
gr.update(visible=True)
@ -273,7 +273,7 @@ def save_video(frames, output_path, fps):
return output_path
# video matting
def video_matting(video_state, end_slider, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
# if interactive_state["track_end_number"]:
# following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@ -301,9 +301,16 @@ def video_matting(video_state, end_slider, interactive_state, mask_dropdown, ero
template_mask[0][0]=1
foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
output_frames = []
foreground_mat = matting_type == "Foreground"
for frame_origin, frame_alpha in zip(following_frames, alpha):
frame_alpha[frame_alpha > 127] = 255
frame_alpha[frame_alpha <= 127] = 0
if foreground_mat:
frame_alpha[frame_alpha > 127] = 255
frame_alpha[frame_alpha <= 127] = 0
else:
frame_temp = frame_alpha.copy()
frame_alpha[frame_temp > 127] = 0
frame_alpha[frame_temp <= 127] = 255
output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
frame_grey = frame_alpha.copy()
frame_grey[frame_alpha == 255] = 127
@ -314,15 +321,19 @@ def video_matting(video_state, end_slider, interactive_state, mask_dropdown, ero
if not os.path.exists("mask_outputs"):
os.makedirs("mask_outputs")
foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(video_state["video_name"]), fps=fps)
file_name= video_state["video_name"]
file_name = ".".join(file_name.split(".")[:-1])
foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps)
# foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(video_state["video_name"]), fps=fps)
alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps)
# alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
return foreground_output, alpha_output
return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
def show_outputs():
return gr.update(visible=True), gr.update(visible=True)
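A minimal sketch of the foreground/background switch introduced above (NumPy only; array shapes are illustrative, not taken from the app): the alpha matte returned by MatAnyone is binarized, and for background matting the binarized values are simply inverted before being combined with the original frame.
import numpy as np
def apply_matte(frame_origin, frame_alpha, foreground_mat=True):
    # frame_origin, frame_alpha: uint8 arrays of the same shape, e.g. (H, W, 3)
    frame_alpha = frame_alpha.copy()
    if foreground_mat:
        frame_alpha[frame_alpha > 127] = 255
        frame_alpha[frame_alpha <= 127] = 0
    else:
        # background matting: invert the binarized matte
        tmp = frame_alpha.copy()
        frame_alpha[tmp > 127] = 0
        frame_alpha[tmp <= 127] = 255
    # as in video_matting(): keep the frame where the matte is 0, black it out where it is 255
    return np.bitwise_and(frame_origin, 255 - frame_alpha)
frame = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
alpha = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
fg_masked = apply_matte(frame, alpha, foreground_mat=True)
bg_masked = apply_matte(frame, alpha, foreground_mat=False)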
def add_audio_to_video(video_path, audio_path, output_path):
try:
video_input = ffmpeg.input(video_path)
@ -392,8 +403,8 @@ def restart():
},
"track_end_number": None,
}, [[],[]], None, None, \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
@ -529,7 +540,16 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
visible=False,
min_width=100,
scale=1)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
matting_type = gr.Radio(
choices=["Foreground", "Background"],
value="Foreground",
label="Matting Type",
info="Type of Video Matting to Generate",
interactive=True,
visible=False,
min_width=100,
scale=1)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False, scale=2)
gr.Markdown("---")
@ -549,9 +569,9 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
with gr.Row():
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
matting_button = gr.Button(value="Video Matting", interactive=True, visible=False, min_width=100)
matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
with gr.Row():
gr.Markdown("")
@ -560,11 +580,11 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
with gr.Column(scale=2):
foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting")
export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting", visible= False)
with gr.Column(scale=2):
alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting")
export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting", visible= False)
export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input])
export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
@ -575,7 +595,7 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
video_input, video_state
],
outputs=[video_state, video_info, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, matting_type, clear_button_click, add_mask_button, matting_button, template_frame,
foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
)
@ -609,9 +629,12 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
# video matting
matting_button.click(
fn=show_outputs,
inputs=[],
outputs=[foreground_video_output, alpha_video_output]).then(
fn=video_matting,
inputs=[video_state, end_selection_slider, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
outputs=[foreground_video_output, alpha_video_output]
inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
outputs=[foreground_video_output, alpha_video_output,foreground_video_output, alpha_video_output, export_to_vace_video_input_btn, export_to_vace_video_mask_btn]
)
# click to get mask
@ -631,7 +654,7 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
click_state,
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
image_selection_slider, end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
],
queue=False,
@ -646,7 +669,7 @@ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger)
click_state,
foreground_video_output, alpha_video_output,
template_frame,
image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
image_selection_slider , end_selection_slider, track_pause_number_slider,point_prompt, export_to_vace_video_input_btn, export_to_vace_video_mask_btn, matting_type, clear_button_click,
add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
],
queue=False,

View File

@ -209,34 +209,47 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, keep_frames= []):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
image_sizes = []
trim_video = len(keep_frames)
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != num_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_frames):
if not keep:
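
A minimal sketch of the prepend-and-pad logic added to prepare_source above (tensor sizes are illustrative; the 0 ≈ preserve / 1 ≈ inpaint mask convention follows the comments in the code): the frames carried over from the previous window are prepended to the new guide with a zero mask so they are kept as-is, and the result is padded up to total_frames with blank frames and a ones mask.
import torch
def prepend_and_pad(guide, mask, pre_guide, total_frames):
    # guide, mask: (3, T, H, W); pre_guide: (3, T_pre, H, W) or None
    if pre_guide is not None:
        guide = torch.cat([pre_guide, guide], dim=1)
        mask = torch.cat([torch.zeros_like(pre_guide), mask], dim=1)   # 0 ~ keep these frames
    c, t, h, w = guide.shape
    if t < total_frames:
        guide = torch.cat([guide, guide.new_zeros(c, total_frames - t, h, w)], dim=1)
        mask = torch.cat([mask, mask.new_ones(c, total_frames - t, h, w)], dim=1)  # 1 ~ inpaint
    return guide, mask
guide = torch.zeros(3, 65, 8, 8)        # hypothetical guide for the new window
overlap = torch.zeros(3, 16, 8, 8)      # frames reused from the previous window
guide, mask = prepend_and_pad(guide, torch.ones_like(guide), overlap, total_frames=81)
print(guide.shape, mask.shape)          # torch.Size([3, 81, 8, 8]) for both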

View File

@ -22,18 +22,18 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def resample(video_fps, video_frames_count, max_frames, target_fps):
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
video_frame_duration = 1 /video_fps
target_frame_duration = 1 / target_fps
cur_time = 0
target_time = 0
frame_no = 0
target_time = start_target_frame * target_frame_duration
frame_no = math.ceil(target_time / video_frame_duration)
cur_time = frame_no * video_frame_duration
frame_ids =[]
while True:
if max_frames != 0 and len(frame_ids) >= max_frames:
if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count :
break
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
frame_no += add_frames_count
@ -42,6 +42,7 @@ def resample(video_fps, video_frames_count, max_frames, target_fps):
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
frame_ids = frame_ids[:max_target_frames_count]
return frame_ids
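A usage sketch of the updated signature (the numbers are illustrative only): the new start_target_frame argument makes resample skip the part of the source that earlier sliding windows already covered, so a later window resumes at the right point in target time.
from wan.utils.utils import resample
# hypothetical 10 s source clip at 30 fps, resampled to 16 fps in two sliding windows
video_fps, src_frames = 30, 300
# first window: up to 81 target frames starting at target frame 0
ids_w1 = resample(video_fps, src_frames, max_target_frames_count=81,
                  target_fps=16, start_target_frame=0)
# second window: resumes at target frame 65 (81 generated minus a 16-frame overlap, purely illustrative)
ids_w2 = resample(video_fps, src_frames, max_target_frames_count=81,
                  target_fps=16, start_target_frame=65)
assert len(ids_w1) <= 81 and len(ids_w2) <= 81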
def get_video_frame(file_name, frame_no):

View File

@ -182,14 +182,14 @@ class VaceVideoProcessor(object):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0):
from wan.utils.utils import resample
target_fps = self.max_fps
# video_frames_count = len(frame_timestamps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame )
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
@ -206,7 +206,7 @@ class VaceVideoProcessor(object):
np.log2(np.sqrt(max_area_z))
)))
seq_len = max_area_z * ((max_frames- 1) // df +1)
seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1)
# of = min(
# (len(frame_ids) - 1) // df + 1,
@ -226,9 +226,9 @@ class VaceVideoProcessor(object):
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
else:
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
@ -238,7 +238,7 @@ class VaceVideoProcessor(object):
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, **kwargs):
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
@ -254,7 +254,7 @@ class VaceVideoProcessor(object):
if src_video != None:
fps = 16
length = src_video.shape[0]
length = src_video.shape[0] + start_frame
if len(readers) > 0:
min_readers = min([len(r) for r in readers])
length = min(length, min_readers )
@ -269,7 +269,7 @@ class VaceVideoProcessor(object):
h, w = src_video.shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame )
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]

337
wgp.py
View File

@ -144,13 +144,23 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return
sliding_window_repeat = inputs["sliding_window_repeat"]
sliding_window = sliding_window_repeat > 0
if "Vace" in model_filename:
video_prompt_type = inputs["video_prompt_type"]
image_refs = inputs["image_refs"]
video_guide = inputs["video_guide"]
video_mask = inputs["video_mask"]
if "Vace" in model_filename and "1.3B" in model_filename :
if sliding_window:
if inputs["repeat_generation"]!=1:
gr.Info("Only one Video generated per Prompt is supported when Sliding windows is used")
return
if inputs["sliding_window_overlap"]>=inputs["video_length"] :
gr.Info("The number of frames of the Sliding Window Overlap must be less than the Number of Frames to Generate")
return
if "1.3B" in model_filename :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = (" and ").join(VACE_SIZE_CONFIGS.keys())
@ -197,6 +207,9 @@ def process_prompt_and_add_tasks(state, model_choice):
image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1)
if sliding_window and len(prompts) > 0:
prompts = ["\n".join(prompts)]
for single_prompt in prompts:
extra_inputs = {
"prompt" : single_prompt,
@ -2053,7 +2066,7 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))
def preprocess_video(process_type, height, width, video_in, max_frames):
def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0):
from wan.utils.utils import resample
@ -2063,8 +2076,10 @@ def preprocess_video(process_type, height, width, video_in, max_frames):
fps = reader.get_avg_fps()
frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
frames_list = reader.get_batch(frame_nos)
if len(frames_list) == 0:
return None
frame_height, frame_width, _ = frames_list[0].shape
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
@ -2187,6 +2202,9 @@ def generate_video(
video_guide,
video_mask,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
@ -2342,41 +2360,6 @@ def generate_video(
else:
raise gr.Error("Teacache not supported for this model")
if "Vace" in model_filename:
# video_prompt_type = video_prompt_type +"G"
if any(process in video_prompt_type for process in ("P", "D", "G")) :
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, 1, 1)
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, status + " - Extracting Open Pose Information"]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, status + " - Extracting Depth Information"]
preprocess_type = "depth"
elif "G" in video_prompt_type :
progress_args = [0, status + " - Extracting Gray Level Information"]
preprocess_type = "gray"
if preprocess_type != None :
progress(*progress_args )
gen["progress_args"] = progress_args
video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
image_refs = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
keep_frames_parsed, error = parse_keep_frames(keep_frames, video_length)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames}")
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
[video_mask],
[image_refs],
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed)
else:
src_video, src_mask, src_ref_images = None, None, None
import random
if seed == None or seed <0:
@ -2393,6 +2376,21 @@ def generate_video(
gen["prompt"] = prompt
repeat_no = 0
extra_generation = 0
sliding_window = sliding_window_repeat > 0
if sliding_window:
start_frame = 0
reuse_frames = sliding_window_overlap
discard_last_frames = sliding_window_discard_last_frames #4
repeat_generation = sliding_window_repeat
prompts = prompt.split("\n")
prompts = [part for part in prompts if len(part)>0]
gen["sliding_window"] = sliding_window
frames_already_processed = None
pre_video_guide = None
while True:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
@ -2400,10 +2398,59 @@ def generate_video(
gen["total_generation"] = total_generation
if abort or repeat_no >= total_generation:
break
if "Vace" in model_filename and (repeat_no == 0 or sliding_window):
if sliding_window:
prompt = prompts[repeat_no] if repeat_no < len(prompts) else prompts[-1]
# video_prompt_type = video_prompt_type +"G"
image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source does in-place modifications
video_guide_copy = video_guide
video_mask_copy = video_mask
if any(process in video_prompt_type for process in ("P", "D", "G")) :
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, 1, 1, sliding_window)
preprocess_type = None
if "P" in video_prompt_type :
progress_args = [0, status + " - Extracting Open Pose Information"]
preprocess_type = "pose"
elif "D" in video_prompt_type :
progress_args = [0, status + " - Extracting Depth Information"]
preprocess_type = "depth"
elif "G" in video_prompt_type :
progress_args = [0, status + " - Extracting Gray Level Information"]
preprocess_type = "gray"
if preprocess_type != None :
progress(*progress_args )
gen["progress_args"] = progress_args
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if repeat_no ==0 else video_length - reuse_frames, start_frame = start_frame)
keep_frames_parsed, error = parse_keep_frames(keep_frames, video_length)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames}")
if repeat_no == 0:
image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy ],
[image_refs_copy],
video_length, image_size = image_size, device ="cpu",
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed,
start_frame = start_frame,
pre_src_video = [pre_video_guide]
)
if repeat_no == 0 and src_video != None and len(src_video) > 0:
image_size = src_video[0].shape[-2:]
else:
src_video, src_mask, src_ref_images = None, None, None
repeat_no +=1
gen["repeat_no"] = repeat_no
prompts_max = gen["prompts_max"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
yield status
@ -2539,6 +2586,15 @@ def generate_video(
# yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
if sliding_window :
start_frame += video_length
if discard_last_frames > 0:
sample = sample[: , :-discard_last_frames]
start_frame -= discard_last_frames
pre_video_guide = sample[:, -reuse_frames:]
if repeat_no > 1:
sample = sample[: , reuse_frames:]
start_frame -= reuse_frames
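A worked example of the frame bookkeeping above (the numbers are illustrative defaults, not prescribed values): with video_length=81, sliding_window_discard_last_frames=4 and sliding_window_overlap=16, every window generates 81 frames, the last 4 are dropped, the final 16 kept frames seed the next window, and windows after the first also drop their first 16 frames because those duplicate the overlap.
video_length, discard_last_frames, reuse_frames = 81, 4, 16   # illustrative settings
start_frame, total_kept = 0, 0
for repeat_no in range(1, 4):                    # three sliding-window iterations
    kept = video_length - discard_last_frames    # frames kept from this window
    start_frame += video_length - discard_last_frames
    if repeat_no > 1:
        kept -= reuse_frames                     # the overlap is already part of the previous window
        start_frame -= reuse_frames
    total_kept += kept
    print(f"window {repeat_no}: kept {kept} frames, next start_frame = {start_frame}")
print("assembled video length:", total_kept)     # 77 + 61 + 61 = 199 frames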
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
@ -2565,7 +2621,13 @@ def generate_video(
if exp > 0:
from rife.inference import temporal_interpolation
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
if sliding_window and repeat_no > 1:
sample = torch.cat([frames_already_processed[:, -2:-1], sample], dim=1)
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
sample = sample[:, 1:]
else:
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
fps = fps * 2**exp
if len(spatial_upsampling) > 0:
@ -2590,6 +2652,12 @@ def generate_video(
new_frames = None
sample = sample * 2 - 1
if sliding_window :
if repeat_no == 1:
frames_already_processed = sample
else:
sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample
cache_video(
tensor=sample[None],
@ -2616,7 +2684,8 @@ def generate_video(
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
state['update_gallery'] = True
seed += 1
if not sliding_window:
seed += 1
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
@ -2694,17 +2763,19 @@ def process_tasks(state, progress=gr.Progress()):
yield f"Total Generation Time: {end_time-start_time:.1f}s"
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max):
if prompts_max == 1:
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, sliding_window):
item = "Sliding Window" if sliding_window else "Sample"
if prompts_max == 1:
if repeat_max == 1:
return "Video"
else:
return f"Sample {repeat_no}/{repeat_max}"
return f"{item} {repeat_no}/{repeat_max}"
else:
if repeat_max == 1:
return f"Prompt {prompt_no}/{prompts_max}"
else:
return f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}"
return f"Prompt {prompt_no}/{prompts_max}, {item} {repeat_no}/{repeat_max}"
refresh_id = 0
@ -2720,7 +2791,8 @@ def update_status(state):
prompts_max = gen.get("prompts_max",0)
total_generation = gen["total_generation"]
repeat_no = gen["repeat_no"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
sliding_window = gen["sliding_window"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
gen["progress_status"] = status
gen["refresh"] = get_new_refresh_id()
@ -2737,7 +2809,7 @@ def one_more_sample(state):
prompts_max = gen.get("prompts_max",0)
total_generation = gen["total_generation"] + extra_orders
repeat_no = gen["repeat_no"]
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, gen.get("sliding_window",False))
gen["progress_status"] = status
@ -3059,7 +3131,7 @@ def prepare_inputs_dict(target, inputs ):
if not "Vace" in model_filename:
unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref"]
unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
for k in unsaved_params:
inputs.pop(k)
@ -3102,6 +3174,9 @@ def save_inputs(
video_guide,
video_mask,
keep_frames,
sliding_window_repeat,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
spatial_upsampling,
@ -3437,7 +3512,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
# video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
# keep_frames = gr.Slider(0, 100, value=ui_defaults.get("keep_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
keep_frames = gr.Text(value=ui_defaults.get("keep_frames","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
image_refs = gr.Gallery( label ="Reference Images",
type ="pil", show_label= True,
@ -3513,28 +3587,32 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Resolution"
)
with gr.Row():
with gr.Column():
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
with gr.Column():
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)")
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui)
with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Column():
seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)")
with gr.Row():
repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0),
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
with gr.Tabs(visible=advanced_ui) as advanced_row:
# with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Tab("Generation"):
with gr.Column():
seed = gr.Slider(-1, 999999999, value=ui_defaults["seed"], step=1, label="Seed (-1 for random)")
with gr.Row():
repeat_generation = gr.Slider(1, 25.0, value=ui_defaults.get("repeat_generation",1), step=1, label="Default Number of Generated Videos per Prompt")
multi_images_gen_type = gr.Dropdown( value=ui_defaults.get("multi_images_gen_type",0),
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
negative_prompt = gr.Textbox(label="Negative Prompt", value=ui_defaults.get("negative_prompt", "") )
with gr.Tab("Loras"):
with gr.Column(visible = True): #as loras_column:
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
loras_choices = gr.Dropdown(
@ -3548,7 +3626,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=launch_multis_str)
with gr.Row():
gr.Markdown("<B>Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)</B>")
with gr.Row():
with gr.Tab("Speed"):
with gr.Column():
gr.Markdown("<B>Tea Cache accelerates the Video generation by skipping denoising steps. This may impact the quality</B>")
tea_cache_setting = gr.Dropdown(
choices=[
("Tea Cache Disabled", 0),
@ -3564,9 +3645,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults.get("tea_cache_start_step_perc",0), step=1, label="Tea Cache starting moment in % of generation")
with gr.Row():
with gr.Tab("Upsampling"):
with gr.Column():
gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
with gr.Row():
temporal_upsampling = gr.Dropdown(
choices=[
("Disabled", ""),
@ -3590,6 +3672,59 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Spatial Upsampling"
)
with gr.Tab("Quality"):
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("slg_switch",0),
visible=True,
scale = 1,
label="Skip Layer guidance"
)
slg_layers = gr.Dropdown(
choices=[
(str(i), i ) for i in range(40)
],
value=ui_defaults.get("slg_layers", ["9"]),
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Tab("Sliding Window", visible= "Vace" in model_filename ) as sliding_window_tab:
with gr.Column(visible= "Vace" in model_filename ) as sliding_window_row:
gr.Markdown("<B>A Sliding Window allows you to generate video longer than those of the model limits</B>")
sliding_window_repeat = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_repeat", 0), step=1, label="Sliding Window Iterations (O=Disabled)")
sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
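Given the frame accounting in generate_video, the length of the assembled video for these three sliders can be estimated as below (an approximation given here for illustration; the helper is not part of the UI code).
def approx_total_frames(video_length, iterations, overlap, discard_last):
    # rough output length of a sliding-window run, ignoring temporal upsampling
    per_window = video_length - discard_last
    return per_window + max(iterations - 1, 0) * (per_window - overlap)
approx_total_frames(81, iterations=3, overlap=16, discard_last=4)   # -> 199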
with gr.Tab("Miscellaneous"):
gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
RIFLEx_setting = gr.Dropdown(
choices=[
@ -3600,50 +3735,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
value=ui_defaults.get("RIFLEx_setting",0),
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("slg_switch",0),
visible=True,
scale = 1,
label="Skip Layer guidance"
)
slg_layers = gr.Dropdown(
choices=[
(str(i), i ) for i in range(40)
],
value=ui_defaults.get("slg_layers", ["9"]),
multiselect= True,
label="Skip Layers",
scale= 3
)
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_start_perc",10), step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults.get("slg_end_perc",90), step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
if not update_form:
with gr.Column():
@ -3697,11 +3791,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
let countdown = 5;
const label = document.getElementById('quit_timer_label');
if (label) {
label.innerText = `Quitting in ${countdown}...`;
label.innerText = `${countdown}...`;
window.quitCountdownInterval = setInterval(() => {
countdown--;
if (countdown > 0) {
label.innerText = `Quitting in ${countdown}...`;
label.innerText = `${countdown}`;
} else {
clearInterval(window.quitCountdownInterval);
findAndClickGradioButton('comfirm_quit_btn_hidden');
@ -3841,7 +3935,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@ -4141,13 +4235,14 @@ def generate_about_tab():
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
def generate_info_tab():