Fixed VACE bugs

DeepBeepMeep 2025-05-05 23:58:21 +02:00
parent cf02cc4004
commit 94d9b4aa4d
7 changed files with 258 additions and 252 deletions


@@ -15,6 +15,7 @@ from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
 from wan.modules.posemb_layers import get_rotary_pos_embed
+from wan.utils.utils import calculate_new_dimensions
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@@ -190,6 +191,7 @@ class DTT2V:
         input_video = None,
         height: int = 480,
         width: int = 832,
+        fit_into_canvas = True,
         num_frames: int = 97,
         num_inference_steps: int = 50,
         shift: float = 1.0,
@@ -221,15 +223,16 @@ class DTT2V:
         i2v_extra_kwrags = {}
         prefix_video = None
         predix_video_latent_length = 0
         if input_video != None:
             _ , _ , height, width = input_video.shape
         elif image != None:
             image = image[0]
             frame_width, frame_height = image.size
-            scale = min(height / frame_height, width / frame_width)
-            height = (int(frame_height * scale) // 16) * 16
-            width = (int(frame_width * scale) // 16) * 16
+            height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas)
             image = np.array(image.resize((width, height))).transpose(2, 0, 1)
         latent_length = (num_frames - 1) // 4 + 1
         latent_height = height // 8
         latent_width = width // 8
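
For illustration only, not part of the diff: with the defaults shown in the signature above (num_frames=97, height=480, width=832), the unchanged latent-geometry lines work out as follows.

num_frames, height, width = 97, 480, 832
latent_length = (num_frames - 1) // 4 + 1   # 25 latent frames
latent_height = height // 8                 # 60
latent_width = width // 8                   # 104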


@@ -25,7 +25,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from wan.modules.posemb_layers import get_rotary_pos_embed
-from wan.utils.utils import resize_lanczos
+from wan.utils.utils import resize_lanczos, calculate_new_dimensions
 def optimized_scale(positive_flat, negative_flat):
@@ -120,7 +120,7 @@ class WanI2V:
         img2 = None,
         height =720,
         width = 1280,
-        max_area=720 * 1280,
+        fit_into_canvas = True,
         frame_num=81,
         shift=5.0,
         sample_solver='unipc',
@@ -189,21 +189,15 @@ class WanI2V:
         frame_num +=1
         lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
         h, w = img.shape[1:]
-        # aspect_ratio = h / w
-        scale1 = min(height / h, width / w)
-        scale2 = min(height / h, width / w)
-        scale = max(scale1, scale2)
-        new_height = int(h * scale)
-        new_width = int(w * scale)
+        h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
         lat_h = round(
-            new_height // self.vae_stride[1] //
+            h // self.vae_stride[1] //
             self.patch_size[1] * self.patch_size[1])
         lat_w = round(
-            new_width // self.vae_stride[2] //
+            w // self.vae_stride[2] //
             self.patch_size[2] * self.patch_size[2])
         h = lat_h * self.vae_stride[1]
         w = lat_w * self.vae_stride[2]
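
For illustration only, not part of the diff: after this change the requested height and width are first resolved by calculate_new_dimensions and then snapped to the VAE/patch grid by the lines above. A minimal sketch of that snapping step, with vae_stride and patch_size left as parameters because their values do not appear in this diff:

def snap_to_latent_grid(h, w, vae_stride, patch_size):
    # Hypothetical standalone helper mirroring the lat_h / lat_w computation above.
    lat_h = round(h // vae_stride[1] // patch_size[1] * patch_size[1])
    lat_w = round(w // vae_stride[2] // patch_size[2] * patch_size[2])
    return lat_h * vae_stride[1], lat_w * vae_stride[2]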


@@ -963,7 +963,7 @@ class WanModel(ModelMixin, ConfigMixin):
             hints_list = [None ] *len(x_list)
         else:
             # Vace embeddings
-            c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+            c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
             c = [u.flatten(2).transpose(1, 2) for u in c]
             c = c[0]
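
For illustration only, not part of the diff: the added .to(...) casts each VACE context tensor to the dtype of the patch-embedding weights before the convolution, avoiding the dtype-mismatch error PyTorch raises when, for example, a float32 hint meets bfloat16 weights. A minimal sketch of the pattern with made-up shapes:

import torch

w = torch.randn(8, 8, dtype=torch.bfloat16)   # stands in for the patch-embedding weights
x = torch.randn(8, 8)                         # float32 context, as the VACE hints may be
# x @ w raises a RuntimeError because the dtypes differ; casting first, as the fixed line does, works:
y = x.to(w.dtype) @ w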


@@ -177,15 +177,16 @@ class WanT2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None):
         image_sizes = []
         trim_video = len(keep_frames)
+        canvas_height, canvas_width = image_size
         for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
             prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
             num_frames = total_frames - prepend_count
             if sub_src_mask is not None and sub_src_video is not None:
-                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
+                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
                 # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
                 # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
                 src_video[i] = src_video[i].to(device)
@@ -208,7 +209,7 @@ class WanT2V:
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
-                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
+                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
                 src_video[i] = src_video[i].to(device)
                 src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
             if prepend_count > 0:
@@ -277,6 +278,7 @@ class WanT2V:
         target_camera=None,
         context_scale=1.0,
         size=(1280, 720),
+        fit_into_canvas = True,
         frame_num=81,
         shift=5.0,
         sample_solver='unipc',
@@ -430,7 +432,7 @@ class WanT2V:
         kwargs.update({'cam_emb': cam_emb})
         if vace:
-            ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
+            ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
             kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
             if overlapped_latents > 0:
                 z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
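
For illustration only, not part of the diff: with the new signature the caller provides the target canvas through image_size and picks the scaling policy with fit_into_canvas, and prepare_source forwards both to the video loaders. A minimal, hypothetical call (argument values invented):

src_video, src_mask, src_ref_images = wan_model.prepare_source(
    [video_guide],                    # guide videos
    [video_mask],                     # inpainting masks, or None
    [image_refs],                     # reference images, or None
    total_frames = 81,
    image_size = (height, width),     # canvas passed on to calculate_new_dimensions
    device = "cuda",
    keep_frames = keep_frames_parsed,
    start_frame = 0,
    fit_into_canvas = fit_canvas,     # 1 = fit inside the max width/height, 0 = match the pixel budget
    pre_src_video = [pre_video_guide],
)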


@@ -67,7 +67,17 @@ def remove_background(img, session=None):
     return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas):
+    if fit_into_canvas:
+        scale1 = min(canvas_height / height, canvas_width / width)
+        scale2 = min(canvas_width / height, canvas_height / width)
+        scale = max(scale1, scale2)
+    else:
+        scale = (canvas_height * canvas_width / (height * width))**(1/2)
+    new_height = round( height * scale / 16) * 16
+    new_width = round( width * scale / 16) * 16
+    return new_height, new_width
 def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
     if rm_background:
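
For illustration only, not part of the diff: the new helper either fits the source inside the canvas (fit_into_canvas truthy) or rescales it to the canvas pixel budget, rounding both sides to multiples of 16. Two worked calls using the function exactly as added above, for a 720x1280 source and a 480x832 canvas:

calculate_new_dimensions(480, 832, 720, 1280, fit_into_canvas=1)   # -> (464, 832): stays within 480x832
calculate_new_dimensions(480, 832, 720, 1280, fit_into_canvas=0)   # -> (480, 848): matches the 480*832 pixel budget, one side may exceed the canvas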


@@ -5,6 +5,7 @@ from PIL import Image
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
+from .utils import calculate_new_dimensions
 class VaceImageProcessor(object):
@@ -182,53 +183,22 @@ class VaceVideoProcessor(object):
-    def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0):
+    def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0):
         from wan.utils.utils import resample
         target_fps = self.max_fps
-        # video_frames_count = len(frame_timestamps)
         frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame )
         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
         h, w = y2 - y1, x2 - x1
-        ratio = h / w
-        df, dh, dw = self.downsample
-        seq_len = self.seq_len
-        # min/max area of the [latent video]
-        min_area_z = self.min_area / (dh * dw)
-        # max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
-        max_area_z = min_area_z # workaround bug
-        # sample a frame number of the [latent video]
-        rand_area_z = np.square(np.power(2, rng.uniform(
-            np.log2(np.sqrt(min_area_z)),
-            np.log2(np.sqrt(max_area_z))
-        )))
-        seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1)
-        # of = min(
-        #     (len(frame_ids) - 1) // df + 1,
-        #     int(seq_len / rand_area_z)
-        # )
-        of = (len(frame_ids) - 1) // df + 1
-        # deduce target shape of the [latent video]
-        # target_area_z = min(max_area_z, int(seq_len / of))
-        target_area_z = max_area_z
-        oh = round(np.sqrt(target_area_z * ratio))
-        ow = int(target_area_z / oh)
-        of = (of - 1) * df + 1
-        oh *= dh
-        ow *= dw
+        oh, ow = calculate_new_dimensions(canvas_height, canvas_width, h, w, fit_into_canvas)
         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
-    def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0):
+    def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True):
         if self.keep_last:
-            return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
+            return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
         else:
             return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
@@ -238,7 +208,7 @@ class VaceVideoProcessor(object):
     def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
         return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
-    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs):
+    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs):
         rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
         # read video
         import decord
@@ -269,7 +239,7 @@ class VaceVideoProcessor(object):
             h, w = src_video.shape[1:3]
         else:
             h, w = readers[0].next().shape[:2]
-        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame )
+        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame )
         # preprocess video
         videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
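
For reference, not part of the diff: reassembled from the added lines above (indentation approximate), the trimmed _get_frameid_bbox_adjust_last now only resamples frame ids, applies the crop box, and delegates output sizing to calculate_new_dimensions.

def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames=0, start_frame=0):
    from wan.utils.utils import resample
    target_fps = self.max_fps
    # choose which source frames to read at the target fps
    frame_ids = resample(fps, video_frames_count, max_frames, target_fps, start_frame)
    # crop box, or the full frame when none is given
    x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
    h, w = y2 - y1, x2 - x1
    # output size now comes from the shared canvas logic instead of the old area sampling
    oh, ow = calculate_new_dimensions(canvas_height, canvas_width, h, w, fit_into_canvas)
    return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps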

wgp.py (413 changed lines)

@@ -84,7 +84,6 @@ def format_time(seconds):
     hours = int(seconds // 3600)
     minutes = int((seconds % 3600) // 60)
     return f"{hours}h {minutes}m"
 def pil_to_base64_uri(pil_image, format="png", quality=75):
     if pil_image is None:
         return None
@@ -275,12 +274,12 @@ def process_prompt_and_add_tasks(state, model_choice):
     video_guide = inputs["video_guide"]
     video_mask = inputs["video_mask"]
-    if "1.3B" in model_filename :
-        resolution_reformated = str(height) + "*" + str(width)
-        if not resolution_reformated in VACE_SIZE_CONFIGS:
-            res = (" and ").join(VACE_SIZE_CONFIGS.keys())
-            gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
-            return
+    # if "1.3B" in model_filename :
+    #     resolution_reformated = str(height) + "*" + str(width)
+    #     if not resolution_reformated in VACE_SIZE_CONFIGS:
+    #         res = (" and ").join(VACE_SIZE_CONFIGS.keys())
+    #         gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
+    #         return
     if "I" in video_prompt_type:
         if image_refs == None:
             gr.Info("You must provide at least one Refererence Image")
@@ -1995,7 +1994,8 @@ def apply_changes( state,
         boost_choice = 1,
         clear_file_list = 0,
         preload_model_policy_choice = 1,
-        UI_theme_choice = "default"
+        UI_theme_choice = "default",
+        fit_canvas_choice = 0
 ):
     if args.lock_config:
         return
@@ -2016,7 +2016,8 @@ def apply_changes( state,
         "boost" : boost_choice,
         "clear_file_list" : clear_file_list,
         "preload_model_policy" : preload_model_policy_choice,
-        "UI_theme" : UI_theme_choice
+        "UI_theme" : UI_theme_choice,
+        "fit_canvas": fit_canvas_choice,
     }
     if Path(server_config_filename).is_file():
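
For illustration only, not part of the diff: the persisted server configuration gains a fit_canvas entry (0 = pixel budget, 1 = maximum width/height). A hypothetical excerpt of the resulting dictionary, with invented values for the other keys shown in this commit:

server_config = {
    "attention_mode": "auto",
    "boost": 1,
    "clear_file_list": 5,
    "preload_model_policy": [],
    "UI_theme": "default",
    "fit_canvas": 0,   # 0 = resize to the pixel budget, 1 = fit into the max width/height
}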
@@ -2050,7 +2051,7 @@ def apply_changes( state,
     transformer_quantization = server_config["transformer_quantization"]
     transformer_types = server_config["transformer_types"]
-    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
+    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas"] for change in changes ):
         model_choice = gr.Dropdown()
     else:
         reload_needed = True
@@ -2413,7 +2414,7 @@ def generate_video(
     file_list = gen["file_list"]
     prompt_no = gen["prompt_no"]
+    fit_canvas = server_config.get("fit_canvas", 0)
     # if wan_model == None:
     #     gr.Info("Unable to generate a Video while a new configuration is being applied.")
     #     return
@@ -2555,7 +2556,7 @@ def generate_video(
     source_video = None
     target_camera = None
     if "recam" in model_filename:
-        source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
+        source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= fit_canvas)
         target_camera = model_mode
     audio_proj_split = None
@@ -2646,7 +2647,7 @@ def generate_video(
         elif diffusion_forcing:
             if video_source != None and len(video_source) > 0 and window_no == 1:
                 keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
-                prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
+                prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= fit_canvas, target_fps = fps)
                 prefix_video = prefix_video .permute(3, 0, 1, 2)
                 prefix_video = prefix_video .float().div_(127.5).sub_(1.) # c, f, h, w
                 prefix_video_frames_count = prefix_video.shape[1]
@@ -2675,13 +2676,13 @@ def generate_video(
             if preprocess_type != None :
                 send_cmd("progress", progress_args)
-                video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
+                video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps)
             keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
             if len(error) > 0:
                 raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
             keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + video_length]
             if window_no == 1:
-                image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
+                image_size = (height, width) # VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
             src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
                 [video_mask_copy ],
                 [image_refs_copy],
@@ -2689,10 +2690,11 @@ def generate_video(
                 original_video= "O" in video_prompt_type,
                 keep_frames=keep_frames_parsed,
                 start_frame = guide_start_frame,
-                pre_src_video = [pre_video_guide]
+                pre_src_video = [pre_video_guide],
+                fit_into_canvas = fit_canvas
             )
-            if window_no == 1 and src_video != None and len(src_video) > 0:
-                image_size = src_video[0].shape[-2:]
+            # if window_no == 1 and src_video != None and len(src_video) > 0:
+            #     image_size = src_video[0].shape[-2:]
             prompts_max = gen["prompts_max"]
             status = get_latest_status(state)
@@ -2722,6 +2724,7 @@ def generate_video(
                 # max_area=MAX_AREA_CONFIGS[resolution_reformated],
                 height = height,
                 width = width,
+                fit_into_canvas = fit_canvas,
                 shift=flow_shift,
                 sampling_steps=num_inference_steps,
                 guide_scale=guidance_scale,
@@ -2750,6 +2753,7 @@ def generate_video(
                 input_video= pre_video_guide,
                 height = height,
                 width = width,
+                fit_into_canvas = fit_canvas,
                 seed = seed,
                 num_frames = (video_length // 4)* 4 + 1, #377
                 num_inference_steps = num_inference_steps,
@@ -2777,6 +2781,7 @@ def generate_video(
                 target_camera= target_camera,
                 frame_num=(video_length // 4)* 4 + 1,
                 size=(width, height),
+                fit_into_canvas = fit_canvas,
                 shift=flow_shift,
                 sampling_steps=num_inference_steps,
                 guide_scale=guidance_scale,
@@ -4042,39 +4047,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None
 wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
 wizard_variables_var = gr.Text(wizard_variables, visible = False)
 with gr.Row():
-    if test_class_i2v(model_filename) and False:
-        resolution = gr.Dropdown(
-            choices=[
-                # 720p
-                ("720p (same amount of pixels)", "1280x720"),
-                ("480p (same amount of pixels)", "832x480"),
-            ],
-            value=ui_defaults.get("resolution","480p"),
-            label="Resolution (video will have the same height / width ratio than the original image)"
-        )
+    if test_class_i2v(model_filename):
+        if server_config.get("fit_canvas", 0) == 1:
+            label = "Max Resolution (as it maybe less depending on video width / height ratio)"
+        else:
+            label = "Max Resolution (as it maybe less depending on video width / height ratio)"
     else:
-        resolution = gr.Dropdown(
-            choices=[
-                # 720p
-                ("1280x720 (16:9, 720p)", "1280x720"),
-                ("720x1280 (9:16, 720p)", "720x1280"),
-                ("1024x1024 (4:3, 720p)", "1024x024"),
-                ("832x1104 (3:4, 720p)", "832x1104"),
-                ("1104x832 (3:4, 720p)", "1104x832"),
-                ("960x960 (1:1, 720p)", "960x960"),
-                # 480p
-                ("960x544 (16:9, 540p)", "960x544"),
-                ("544x960 (16:9, 540p)", "544x960"),
-                ("832x480 (16:9, 480p)", "832x480"),
-                ("480x832 (9:16, 480p)", "480x832"),
-                ("832x624 (4:3, 480p)", "832x624"),
-                ("624x832 (3:4, 480p)", "624x832"),
-                ("720x720 (1:1, 480p)", "720x720"),
-                ("512x512 (1:1, 480p)", "512x512"),
-            ],
-            value=ui_defaults.get("resolution","832x480"),
-            label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
-        )
+        label = "Max Resolution (as it maybe less depending on video width / height ratio)"
+    resolution = gr.Dropdown(
+        choices=[
+            # 720p
+            ("1280x720 (16:9, 720p)", "1280x720"),
+            ("720x1280 (9:16, 720p)", "720x1280"),
+            ("1024x1024 (4:3, 720p)", "1024x024"),
+            ("832x1104 (3:4, 720p)", "832x1104"),
+            ("1104x832 (3:4, 720p)", "1104x832"),
+            ("960x960 (1:1, 720p)", "960x960"),
+            # 480p
+            ("960x544 (16:9, 540p)", "960x544"),
+            ("544x960 (16:9, 540p)", "544x960"),
+            ("832x480 (16:9, 480p)", "832x480"),
+            ("480x832 (9:16, 480p)", "480x832"),
+            ("832x624 (4:3, 480p)", "832x624"),
+            ("624x832 (3:4, 480p)", "624x832"),
+            ("720x720 (1:1, 480p)", "720x720"),
+            ("512x512 (1:1, 480p)", "512x512"),
+        ],
+        value=ui_defaults.get("resolution","832x480"),
+        label= label
+    )
 with gr.Row():
     if recammaster:
         video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
@@ -4556,156 +4557,181 @@ def generate_configuration_tab(state, blocks, header, model_choice):
 with gr.Column():
     model_list = []
-    for model_type in model_types:
-        choice = get_model_filename(model_type, transformer_quantization)
-        model_list.append(choice)
-    dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
-    transformer_types_choices = gr.Dropdown(
-        choices= dropdown_choices,
-        value= transformer_types,
-        label= "Selectable Wan Transformer Models (keep empty to get All of them)",
-        scale= 2,
-        multiselect= True
-    )
-    quantization_choice = gr.Dropdown(
-        choices=[
-            ("Scaled Int8 Quantization (recommended)", "int8"),
-            ("16 bits (no quantization)", "bf16"),
-        ],
-        value= transformer_quantization,
-        label="Wan Transformer Model Quantization Type (if available)",
-    )
+    with gr.Tabs():
+        # with gr.Row(visible=advanced_ui) as advanced_row:
+        with gr.Tab("General"):
+            for model_type in model_types:
+                choice = get_model_filename(model_type, transformer_quantization)
+                model_list.append(choice)
+            dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
+            transformer_types_choices = gr.Dropdown(
+                choices= dropdown_choices,
+                value= transformer_types,
+                label= "Selectable Wan Transformer Models (keep empty to get All of them)",
+                scale= 2,
+                multiselect= True
+            )
-    mixed_precision_choice = gr.Dropdown(
-        choices=[
-            ("16 bits only, requires less VRAM", "0"),
-            ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
-        ],
-        value= server_config.get("mixed_precision", "0"),
-        label="Transformer Engine Calculation"
-    )
-    index = text_encoder_choices.index(text_encoder_filename)
-    index = 0 if index ==0 else index
-    text_encoder_choice = gr.Dropdown(
-        choices=[
-            ("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
-            ("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
-        ],
-        value= index,
-        label="Text Encoder model"
-    )
+            fit_canvas_choice = gr.Dropdown(
+                choices=[
+                    ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0),
+                    ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1),
+                ],
+                value= server_config.get("fit_canvas", 0),
+                label="Generated Video Dimensions when Prompt contains an Image or a Video",
+                interactive= not lock_ui_attention
+            )
-    VAE_precision_choice = gr.Dropdown(
-        choices=[
-            ("16 bits, requires less VRAM and faster", "16"),
-            ("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
-        ],
-        value= server_config.get("vae_precision", "16"),
-        label="VAE Encoding / Decoding precision"
-    )
-    save_path_choice = gr.Textbox(
-        label="Output Folder for Generated Videos",
-        value=server_config.get("save_path", save_path)
-    )
+            def check(mode):
+                if not mode in attention_modes_installed:
+                    return " (NOT INSTALLED)"
+                elif not mode in attention_modes_supported:
+                    return " (NOT SUPPORTED)"
+                else:
+                    return ""
+            attention_choice = gr.Dropdown(
+                choices=[
+                    ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
+                    ("Scale Dot Product Attention: default, always available", "sdpa"),
+                    ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
+                    ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
+                    ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
+                    ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
+                ],
+                value= attention_mode,
+                label="Attention Type",
+                interactive= not lock_ui_attention
+            )
-    def check(mode):
-        if not mode in attention_modes_installed:
-            return " (NOT INSTALLED)"
-        elif not mode in attention_modes_supported:
-            return " (NOT SUPPORTED)"
-        else:
-            return ""
-    attention_choice = gr.Dropdown(
-        choices=[
-            ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
-            ("Scale Dot Product Attention: default, always available", "sdpa"),
-            ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
-            ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
-            ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
-            ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
-        ],
-        value= attention_mode,
-        label="Attention Type",
-        interactive= not lock_ui_attention
-    )
-    gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation")
-    compile_choice = gr.Dropdown(
-        choices=[
-            ("ON: works only on Linux / WSL", "transformer"),
-            ("OFF: no other choice if you have Windows without using WSL", "" ),
-        ],
-        value= compile,
-        label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
-        interactive= not lock_ui_compile
-    )
-    vae_config_choice = gr.Dropdown(
-        choices=[
-            ("Auto", 0),
-            ("Disabled (faster but may require up to 22 GB of VRAM)", 1),
-            ("256 x 256 : If at least 8 GB of VRAM", 2),
-            ("128 x 128 : If at least 6 GB of VRAM", 3),
-        ],
-        value= vae_config,
-        label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
-    )
-    boost_choice = gr.Dropdown(
-        choices=[
-            # ("Auto (ON if Video longer than 5s)", 0),
-            ("ON", 1),
-            ("OFF", 2),
-        ],
-        value=boost,
-        label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
-    )
-    profile_choice = gr.Dropdown(
-        choices=[
-            ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
-            ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
-            ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
-            ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
-            ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
-        ],
-        value= profile,
-        label="Profile (for power users only, not needed to change it)"
-    )
             metadata_choice = gr.Dropdown(
                 choices=[
                     ("Export JSON files", "json"),
                     ("Add metadata to video", "metadata"),
                     ("Neither", "none")
                 ],
                 value=server_config.get("metadata_type", "metadata"),
                 label="Metadata Handling"
             )
             preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
                 value=server_config.get("preload_model_policy",[]),
                 label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
             )
-    clear_file_list_choice = gr.Dropdown(
-        choices=[
-            ("None", 0),
-            ("Keep the last video", 1),
-            ("Keep the last 5 videos", 5),
-            ("Keep the last 10 videos", 10),
-            ("Keep the last 20 videos", 20),
-            ("Keep the last 30 videos", 30),
-        ],
-        value=server_config.get("clear_file_list", 5),
-        label="Keep Previously Generated Videos when starting a new Generation Batch"
-    )
-    UI_theme_choice = gr.Dropdown(
-        choices=[
-            ("Blue Sky", "default"),
-            ("Classic Gradio", "gradio"),
-        ],
-        value=server_config.get("UI_theme_choice", "default"),
-        label="User Interface Theme. You will need to restart the App the see new Theme."
-    )
+            save_path_choice = gr.Textbox(
+                label="Output Folder for Generated Videos",
+                value=server_config.get("save_path", save_path)
+            )
+        with gr.Tab("Performance"):
+            quantization_choice = gr.Dropdown(
+                choices=[
+                    ("Scaled Int8 Quantization (recommended)", "int8"),
+                    ("16 bits (no quantization)", "bf16"),
+                ],
+                value= transformer_quantization,
+                label="Wan Transformer Model Quantization Type (if available)",
+            )
+            mixed_precision_choice = gr.Dropdown(
+                choices=[
+                    ("16 bits only, requires less VRAM", "0"),
+                    ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
+                ],
+                value= server_config.get("mixed_precision", "0"),
+                label="Transformer Engine Calculation"
+            )
+            index = text_encoder_choices.index(text_encoder_filename)
+            index = 0 if index ==0 else index
+            text_encoder_choice = gr.Dropdown(
+                choices=[
+                    ("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
+                    ("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
+                ],
+                value= index,
+                label="Text Encoder model"
+            )
+            VAE_precision_choice = gr.Dropdown(
+                choices=[
+                    ("16 bits, requires less VRAM and faster", "16"),
+                    ("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
+                ],
+                value= server_config.get("vae_precision", "16"),
+                label="VAE Encoding / Decoding precision"
+            )
+            gr.Text("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation", interactive= False, show_label= False )
+            compile_choice = gr.Dropdown(
+                choices=[
+                    ("ON: works only on Linux / WSL", "transformer"),
+                    ("OFF: no other choice if you have Windows without using WSL", "" ),
+                ],
+                value= compile,
+                label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
+                interactive= not lock_ui_compile
+            )
+            vae_config_choice = gr.Dropdown(
+                choices=[
+                    ("Auto", 0),
+                    ("Disabled (faster but may require up to 22 GB of VRAM)", 1),
+                    ("256 x 256 : If at least 8 GB of VRAM", 2),
+                    ("128 x 128 : If at least 6 GB of VRAM", 3),
+                ],
+                value= vae_config,
+                label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
+            )
+            boost_choice = gr.Dropdown(
+                choices=[
+                    # ("Auto (ON if Video longer than 5s)", 0),
+                    ("ON", 1),
+                    ("OFF", 2),
+                ],
+                value=boost,
+                label="Boost: Give a 10% speedup without losing quality at the cost of a litle VRAM (up to 1GB at max frames and resolution)"
+            )
+            profile_choice = gr.Dropdown(
+                choices=[
+                    ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
+                    ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
+                    ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
+                    ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
+                    ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
+                ],
+                value= profile,
+                label="Profile (for power users only, not needed to change it)"
+            )
+            clear_file_list_choice = gr.Dropdown(
+                choices=[
+                    ("None", 0),
+                    ("Keep the last video", 1),
+                    ("Keep the last 5 videos", 5),
+                    ("Keep the last 10 videos", 10),
+                    ("Keep the last 20 videos", 20),
+                    ("Keep the last 30 videos", 30),
+                ],
+                value=server_config.get("clear_file_list", 5),
+                label="Keep Previously Generated Videos when starting a Generation Batch"
+            )
+            UI_theme_choice = gr.Dropdown(
+                choices=[
+                    ("Blue Sky", "default"),
+                    ("Classic Gradio", "gradio"),
+                ],
+                value=server_config.get("UI_theme_choice", "default"),
+                label="User Interface Theme. You will need to restart the App the see new Theme."
+            )
     msg = gr.Markdown()
@@ -4728,7 +4754,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
         boost_choice,
         clear_file_list_choice,
         preload_model_policy_choice,
-        UI_theme_choice
+        UI_theme_choice,
+        fit_canvas_choice
     ],
     outputs= [msg , header, model_choice]
 )