From 94d9b4aa4d539ea054adcc78971dd807453c159e Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 5 May 2025 23:58:21 +0200 Subject: [PATCH] fixed vace bugs --- wan/diffusion_forcing.py | 9 +- wan/image2video.py | 20 +- wan/modules/model.py | 2 +- wan/text2video.py | 10 +- wan/utils/utils.py | 10 + wan/utils/vace_preprocessor.py | 44 +--- wgp.py | 415 ++++++++++++++++++--------------- 7 files changed, 258 insertions(+), 252 deletions(-) diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index af9627b..87352eb 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -15,6 +15,7 @@ from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE from wan.modules.posemb_layers import get_rotary_pos_embed +from wan.utils.utils import calculate_new_dimensions from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -190,6 +191,7 @@ class DTT2V: input_video = None, height: int = 480, width: int = 832, + fit_into_canvas = True, num_frames: int = 97, num_inference_steps: int = 50, shift: float = 1.0, @@ -221,15 +223,16 @@ class DTT2V: i2v_extra_kwrags = {} prefix_video = None predix_video_latent_length = 0 + if input_video != None: _ , _ , height, width = input_video.shape elif image != None: image = image[0] frame_width, frame_height = image.size - scale = min(height / frame_height, width / frame_width) - height = (int(frame_height * scale) // 16) * 16 - width = (int(frame_width * scale) // 16) * 16 + height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas) image = np.array(image.resize((width, height))).transpose(2, 0, 1) + + latent_length = (num_frames - 1) // 4 + 1 latent_height = height // 8 latent_width = width // 8 diff --git a/wan/image2video.py b/wan/image2video.py index 996d5b5..7b0fcb7 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -25,7 +25,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed -from wan.utils.utils import resize_lanczos +from wan.utils.utils import resize_lanczos, calculate_new_dimensions def optimized_scale(positive_flat, negative_flat): @@ -120,7 +120,7 @@ class WanI2V: img2 = None, height =720, width = 1280, - max_area=720 * 1280, + fit_into_canvas = True, frame_num=81, shift=5.0, sample_solver='unipc', @@ -188,22 +188,16 @@ class WanI2V: if add_frames_for_end_image: frame_num +=1 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - - + h, w = img.shape[1:] - # aspect_ratio = h / w - - scale1 = min(height / h, width / w) - scale2 = min(height / h, width / w) - scale = max(scale1, scale2) - new_height = int(h * scale) - new_width = int(w * scale) + h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + lat_h = round( - new_height // self.vae_stride[1] // + h // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) lat_w = round( - new_width // self.vae_stride[2] // + w // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) h = lat_h * self.vae_stride[1] w = lat_w * self.vae_stride[2] diff --git a/wan/modules/model.py b/wan/modules/model.py index 77292e2..50f1d8c 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -963,7 +963,7 @@ class WanModel(ModelMixin, ConfigMixin): hints_list = [None ] 
*len(x_list) else: # Vace embeddings - c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] c = c[0] diff --git a/wan/text2video.py b/wan/text2video.py index e85aaca..7fd7935 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -177,15 +177,16 @@ class WanT2V: def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None): image_sizes = [] trim_video = len(keep_frames) + canvas_height, canvas_width = image_size for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] num_frames = total_frames - prepend_count if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas) # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) src_video[i] = src_video[i].to(device) @@ -208,7 +209,7 @@ class WanT2V: src_mask[i] = torch.ones_like(src_video[i], device=device) image_sizes.append(image_size) else: - src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas) src_video[i] = src_video[i].to(device) src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) if prepend_count > 0: @@ -277,6 +278,7 @@ class WanT2V: target_camera=None, context_scale=1.0, size=(1280, 720), + fit_into_canvas = True, frame_num=81, shift=5.0, sample_solver='unipc', @@ -430,7 +432,7 @@ class WanT2V: kwargs.update({'cam_emb': cam_emb}) if vace: - ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0 + ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale}) if overlapped_latents > 0: z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z] diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 0a64dd2..8cedff5 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -67,7 +67,17 @@ def remove_background(img, 
session=None): return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) +def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas): + if fit_into_canvas: + scale1 = min(canvas_height / height, canvas_width / width) + scale2 = min(canvas_width / height, canvas_height / width) + scale = max(scale1, scale2) + else: + scale = (canvas_height * canvas_width / (height * width))**(1/2) + new_height = round( height * scale / 16) * 16 + new_width = round( width * scale / 16) * 16 + return new_height, new_width def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ): if rm_background: diff --git a/wan/utils/vace_preprocessor.py b/wan/utils/vace_preprocessor.py index c591cca..0722a17 100644 --- a/wan/utils/vace_preprocessor.py +++ b/wan/utils/vace_preprocessor.py @@ -5,6 +5,7 @@ from PIL import Image import torch import torch.nn.functional as F import torchvision.transforms.functional as TF +from .utils import calculate_new_dimensions class VaceImageProcessor(object): @@ -182,53 +183,22 @@ class VaceVideoProcessor(object): - def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0): + def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0): from wan.utils.utils import resample target_fps = self.max_fps - # video_frames_count = len(frame_timestamps) - frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame ) x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box h, w = y2 - y1, x2 - x1 - ratio = h / w - df, dh, dw = self.downsample - seq_len = self.seq_len - # min/max area of the [latent video] - min_area_z = self.min_area / (dh * dw) - # max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) - max_area_z = min_area_z # workaround bug - # sample a frame number of the [latent video] - rand_area_z = np.square(np.power(2, rng.uniform( - np.log2(np.sqrt(min_area_z)), - np.log2(np.sqrt(max_area_z)) - ))) - - seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1) - - # of = min( - # (len(frame_ids) - 1) // df + 1, - # int(seq_len / rand_area_z) - # ) - of = (len(frame_ids) - 1) // df + 1 - - - # deduce target shape of the [latent video] - # target_area_z = min(max_area_z, int(seq_len / of)) - target_area_z = max_area_z - oh = round(np.sqrt(target_area_z * ratio)) - ow = int(target_area_z / oh) - of = (of - 1) * df + 1 - oh *= dh - ow *= dw + oh, ow = calculate_new_dimensions(canvas_height, canvas_width, h, w, fit_into_canvas) return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps - def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0): + def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True): if self.keep_last: - return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame) + return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame) else: return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames) @@ -238,7 +208,7 @@ class VaceVideoProcessor(object): def 
load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) - def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs): + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs): rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) # read video import decord @@ -269,7 +239,7 @@ class VaceVideoProcessor(object): h, w = src_video.shape[1:3] else: h, w = readers[0].next().shape[:2] - frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame ) + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame ) # preprocess video videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] diff --git a/wgp.py b/wgp.py index 00ebf9e..935966e 100644 --- a/wgp.py +++ b/wgp.py @@ -84,7 +84,6 @@ def format_time(seconds): hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) return f"{hours}h {minutes}m" - def pil_to_base64_uri(pil_image, format="png", quality=75): if pil_image is None: return None @@ -275,12 +274,12 @@ def process_prompt_and_add_tasks(state, model_choice): video_guide = inputs["video_guide"] video_mask = inputs["video_mask"] - if "1.3B" in model_filename : - resolution_reformated = str(height) + "*" + str(width) - if not resolution_reformated in VACE_SIZE_CONFIGS: - res = (" and ").join(VACE_SIZE_CONFIGS.keys()) - gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") - return + # if "1.3B" in model_filename : + # resolution_reformated = str(height) + "*" + str(width) + # if not resolution_reformated in VACE_SIZE_CONFIGS: + # res = (" and ").join(VACE_SIZE_CONFIGS.keys()) + # gr.Info(f"Video Resolution for Vace model is not supported. 
Only {res} resolutions are allowed.") + # return if "I" in video_prompt_type: if image_refs == None: gr.Info("You must provide at least one Refererence Image") @@ -1995,7 +1994,8 @@ def apply_changes( state, boost_choice = 1, clear_file_list = 0, preload_model_policy_choice = 1, - UI_theme_choice = "default" + UI_theme_choice = "default", + fit_canvas_choice = 0 ): if args.lock_config: return @@ -2016,7 +2016,8 @@ def apply_changes( state, "boost" : boost_choice, "clear_file_list" : clear_file_list, "preload_model_policy" : preload_model_policy_choice, - "UI_theme" : UI_theme_choice + "UI_theme" : UI_theme_choice, + "fit_canvas": fit_canvas_choice, } if Path(server_config_filename).is_file(): @@ -2050,7 +2051,7 @@ def apply_changes( state, transformer_quantization = server_config["transformer_quantization"] transformer_types = server_config["transformer_types"] - if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ): + if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas"] for change in changes ): model_choice = gr.Dropdown() else: reload_needed = True @@ -2413,7 +2414,7 @@ def generate_video( file_list = gen["file_list"] prompt_no = gen["prompt_no"] - + fit_canvas = server_config.get("fit_canvas", 0) # if wan_model == None: # gr.Info("Unable to generate a Video while a new configuration is being applied.") # return @@ -2555,7 +2556,7 @@ def generate_video( source_video = None target_camera = None if "recam" in model_filename: - source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True) + source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= fit_canvas) target_camera = model_mode audio_proj_split = None @@ -2646,7 +2647,7 @@ def generate_video( elif diffusion_forcing: if video_source != None and len(video_source) > 0 and window_no == 1: keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source) - prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps) + prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= fit_canvas, target_fps = fps) prefix_video = prefix_video .permute(3, 0, 1, 2) prefix_video = prefix_video .float().div_(127.5).sub_(1.) 
# c, f, h, w prefix_video_frames_count = prefix_video.shape[1] @@ -2675,13 +2676,13 @@ def generate_video( if preprocess_type != None : send_cmd("progress", progress_args) - video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps) + video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps) keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + video_length] if window_no == 1: - image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any) + image_size = (height, width) # VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any) src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy], [video_mask_copy ], [image_refs_copy], @@ -2689,10 +2690,11 @@ def generate_video( original_video= "O" in video_prompt_type, keep_frames=keep_frames_parsed, start_frame = guide_start_frame, - pre_src_video = [pre_video_guide] + pre_src_video = [pre_video_guide], + fit_into_canvas = fit_canvas ) - if window_no == 1 and src_video != None and len(src_video) > 0: - image_size = src_video[0].shape[-2:] + # if window_no == 1 and src_video != None and len(src_video) > 0: + # image_size = src_video[0].shape[-2:] prompts_max = gen["prompts_max"] status = get_latest_status(state) @@ -2722,6 +2724,7 @@ def generate_video( # max_area=MAX_AREA_CONFIGS[resolution_reformated], height = height, width = width, + fit_into_canvas = fit_canvas, shift=flow_shift, sampling_steps=num_inference_steps, guide_scale=guidance_scale, @@ -2750,6 +2753,7 @@ def generate_video( input_video= pre_video_guide, height = height, width = width, + fit_into_canvas = fit_canvas, seed = seed, num_frames = (video_length // 4)* 4 + 1, #377 num_inference_steps = num_inference_steps, @@ -2777,6 +2781,7 @@ def generate_video( target_camera= target_camera, frame_num=(video_length // 4)* 4 + 1, size=(width, height), + fit_into_canvas = fit_canvas, shift=flow_shift, sampling_steps=num_inference_steps, guide_scale=guidance_scale, @@ -4042,39 +4047,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) with gr.Row(): - if test_class_i2v(model_filename) and False: - resolution = gr.Dropdown( - choices=[ - # 720p - ("720p (same amount of pixels)", "1280x720"), - ("480p (same amount of pixels)", "832x480"), - ], - value=ui_defaults.get("resolution","480p"), - label="Resolution (video will have the same height / width ratio than the original image)" - ) + if test_class_i2v(model_filename): + if server_config.get("fit_canvas", 0) == 1: + label = "Max Resolution (as it maybe less depending on video width / height ratio)" + else: + label = "Max Resolution (as it maybe less depending on video width / height ratio)" else: - resolution = gr.Dropdown( - 
choices=[ - # 720p - ("1280x720 (16:9, 720p)", "1280x720"), - ("720x1280 (9:16, 720p)", "720x1280"), - ("1024x1024 (4:3, 720p)", "1024x024"), - ("832x1104 (3:4, 720p)", "832x1104"), - ("1104x832 (3:4, 720p)", "1104x832"), - ("960x960 (1:1, 720p)", "960x960"), - # 480p - ("960x544 (16:9, 540p)", "960x544"), - ("544x960 (16:9, 540p)", "544x960"), - ("832x480 (16:9, 480p)", "832x480"), - ("480x832 (9:16, 480p)", "480x832"), - ("832x624 (4:3, 480p)", "832x624"), - ("624x832 (3:4, 480p)", "624x832"), - ("720x720 (1:1, 480p)", "720x720"), - ("512x512 (1:1, 480p)", "512x512"), - ], - value=ui_defaults.get("resolution","832x480"), - label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution" - ) + label = "Max Resolution (as it maybe less depending on video width / height ratio)" + resolution = gr.Dropdown( + choices=[ + # 720p + ("1280x720 (16:9, 720p)", "1280x720"), + ("720x1280 (9:16, 720p)", "720x1280"), + ("1024x1024 (4:3, 720p)", "1024x024"), + ("832x1104 (3:4, 720p)", "832x1104"), + ("1104x832 (3:4, 720p)", "1104x832"), + ("960x960 (1:1, 720p)", "960x960"), + # 480p + ("960x544 (16:9, 540p)", "960x544"), + ("544x960 (16:9, 540p)", "544x960"), + ("832x480 (16:9, 480p)", "832x480"), + ("480x832 (9:16, 480p)", "480x832"), + ("832x624 (4:3, 480p)", "832x624"), + ("624x832 (3:4, 480p)", "624x832"), + ("720x720 (1:1, 480p)", "720x720"), + ("512x512 (1:1, 480p)", "512x512"), + ], + value=ui_defaults.get("resolution","832x480"), + label= label + ) with gr.Row(): if recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False) @@ -4556,156 +4557,181 @@ def generate_configuration_tab(state, blocks, header, model_choice): with gr.Column(): model_list = [] - for model_type in model_types: - choice = get_model_filename(model_type, transformer_quantization) - model_list.append(choice) - dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] - transformer_types_choices = gr.Dropdown( - choices= dropdown_choices, - value= transformer_types, - label= "Selectable Wan Transformer Models (keep empty to get All of them)", - scale= 2, - multiselect= True - ) - quantization_choice = gr.Dropdown( - choices=[ - ("Scaled Int8 Quantization (recommended)", "int8"), - ("16 bits (no quantization)", "bf16"), - ], - value= transformer_quantization, - label="Wan Transformer Model Quantization Type (if available)", - ) + with gr.Tabs(): + # with gr.Row(visible=advanced_ui) as advanced_row: + with gr.Tab("General"): + for model_type in model_types: + choice = get_model_filename(model_type, transformer_quantization) + model_list.append(choice) + dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list] + transformer_types_choices = gr.Dropdown( + choices= dropdown_choices, + value= transformer_types, + label= "Selectable Wan Transformer Models (keep empty to get All of them)", + scale= 2, + multiselect= True + ) - mixed_precision_choice = gr.Dropdown( - choices=[ - ("16 bits only, requires less VRAM", "0"), - ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"), - ], - value= server_config.get("mixed_precision", "0"), - label="Transformer Engine Calculation" - ) + fit_canvas_choice = gr.Dropdown( + choices=[ + ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width 
may exceed the requested dimensions )", 0), + ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1), + ], + value= server_config.get("fit_canvas", 0), + label="Generated Video Dimensions when Prompt contains an Image or a Video", + interactive= not lock_ui_attention + ) - index = text_encoder_choices.index(text_encoder_filename) - index = 0 if index ==0 else index - text_encoder_choice = gr.Dropdown( - choices=[ - ("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0), - ("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1), - ], - value= index, - label="Text Encoder model" - ) - VAE_precision_choice = gr.Dropdown( - choices=[ - ("16 bits, requires less VRAM and faster", "16"), - ("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"), - ], - value= server_config.get("vae_precision", "16"), - label="VAE Encoding / Decoding precision" - ) + def check(mode): + if not mode in attention_modes_installed: + return " (NOT INSTALLED)" + elif not mode in attention_modes_supported: + return " (NOT SUPPORTED)" + else: + return "" + attention_choice = gr.Dropdown( + choices=[ + ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"), + ("Scale Dot Product Attention: default, always available", "sdpa"), + ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"), + ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"), + ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"), + ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"), + ], + value= attention_mode, + label="Attention Type", + interactive= not lock_ui_attention + ) - save_path_choice = gr.Textbox( - label="Output Folder for Generated Videos", - value=server_config.get("save_path", save_path) - ) - def check(mode): - if not mode in attention_modes_installed: - return " (NOT INSTALLED)" - elif not mode in attention_modes_supported: - return " (NOT SUPPORTED)" - else: - return "" - attention_choice = gr.Dropdown( - choices=[ - ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"), - ("Scale Dot Product Attention: default, always available", "sdpa"), - ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"), - ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"), - ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"), - ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"), - ], - value= attention_mode, - label="Attention Type", - interactive= not lock_ui_attention - ) - gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a 
duration / resolution may last a few minutes due to recompilation") - compile_choice = gr.Dropdown( - choices=[ - ("ON: works only on Linux / WSL", "transformer"), - ("OFF: no other choice if you have Windows without using WSL", "" ), - ], - value= compile, - label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)", - interactive= not lock_ui_compile - ) - vae_config_choice = gr.Dropdown( - choices=[ - ("Auto", 0), - ("Disabled (faster but may require up to 22 GB of VRAM)", 1), - ("256 x 256 : If at least 8 GB of VRAM", 2), - ("128 x 128 : If at least 6 GB of VRAM", 3), - ], - value= vae_config, - label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)" - ) - boost_choice = gr.Dropdown( - choices=[ - # ("Auto (ON if Video longer than 5s)", 0), - ("ON", 1), - ("OFF", 2), - ], - value=boost, - label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)" - ) - profile_choice = gr.Dropdown( - choices=[ - ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1), - ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2), - ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3), - ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4), - ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5) - ], - value= profile, - label="Profile (for power users only, not needed to change it)" - ) - - metadata_choice = gr.Dropdown( - choices=[ - ("Export JSON files", "json"), - ("Add metadata to video", "metadata"), - ("Neither", "none") - ], - value=server_config.get("metadata_type", "metadata"), - label="Metadata Handling" - ) - preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")], - value=server_config.get("preload_model_policy",[]), - label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)" - ) - clear_file_list_choice = gr.Dropdown( - choices=[ - ("None", 0), - ("Keep the last video", 1), - ("Keep the last 5 videos", 5), - ("Keep the last 10 videos", 10), - ("Keep the last 20 videos", 20), - ("Keep the last 30 videos", 30), - ], - value=server_config.get("clear_file_list", 5), - label="Keep Previously Generated Videos when starting a Generation Batch" - ) + metadata_choice = gr.Dropdown( + choices=[ + ("Export JSON files", "json"), + ("Add metadata to video", "metadata"), + ("Neither", "none") + ], + value=server_config.get("metadata_type", "metadata"), + label="Metadata Handling" + ) + preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")], + value=server_config.get("preload_model_policy",[]), + label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed 
once the queue has been processed)"
+                    )
+
+                    clear_file_list_choice = gr.Dropdown(
+                        choices=[
+                            ("None", 0),
+                            ("Keep the last video", 1),
+                            ("Keep the last 5 videos", 5),
+                            ("Keep the last 10 videos", 10),
+                            ("Keep the last 20 videos", 20),
+                            ("Keep the last 30 videos", 30),
+                        ],
+                        value=server_config.get("clear_file_list", 5),
+                        label="Keep Previously Generated Videos when starting a new Generation Batch"
+                    )
+
+                    UI_theme_choice = gr.Dropdown(
+                        choices=[
+                            ("Blue Sky", "default"),
+                            ("Classic Gradio", "gradio"),
+                        ],
+                        value=server_config.get("UI_theme_choice", "default"),
+                        label="User Interface Theme. You will need to restart the App to see the new Theme."
+                    )
+
+                    save_path_choice = gr.Textbox(
+                        label="Output Folder for Generated Videos",
+                        value=server_config.get("save_path", save_path)
+                    )
+
+                with gr.Tab("Performance"):
+
+                    quantization_choice = gr.Dropdown(
+                        choices=[
+                            ("Scaled Int8 Quantization (recommended)", "int8"),
+                            ("16 bits (no quantization)", "bf16"),
+                        ],
+                        value= transformer_quantization,
+                        label="Wan Transformer Model Quantization Type (if available)",
+                    )
+
+                    mixed_precision_choice = gr.Dropdown(
+                        choices=[
+                            ("16 bits only, requires less VRAM", "0"),
+                            ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
+                        ],
+                        value= server_config.get("mixed_precision", "0"),
+                        label="Transformer Engine Calculation"
+                    )
+
+                    index = text_encoder_choices.index(text_encoder_filename)
+                    index = 0 if index ==0 else index
+                    text_encoder_choice = gr.Dropdown(
+                        choices=[
+                            ("UMT5 XXL 16 bits - unquantized text encoder, better quality but uses more RAM", 0),
+                            ("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
+                        ],
+                        value= index,
+                        label="Text Encoder model"
+                    )
+
+                    VAE_precision_choice = gr.Dropdown(
+                        choices=[
+                            ("16 bits, requires less VRAM and faster", "16"),
+                            ("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
+                        ],
+                        value= server_config.get("vae_precision", "16"),
+                        label="VAE Encoding / Decoding precision"
+                    )
+
+                    gr.Text("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation", interactive= False, show_label= False )
+                    compile_choice = gr.Dropdown(
+                        choices=[
+                            ("ON: works only on Linux / WSL", "transformer"),
+                            ("OFF: no other choice if you have Windows without using WSL", "" ),
+                        ],
+                        value= compile,
+                        label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
+                        interactive= not lock_ui_compile
+                    )
+
+                    vae_config_choice = gr.Dropdown(
+                        choices=[
+                            ("Auto", 0),
+                            ("Disabled (faster but may require up to 22 GB of VRAM)", 1),
+                            ("256 x 256 : If at least 8 GB of VRAM", 2),
+                            ("128 x 128 : If at least 6 GB of VRAM", 3),
+                        ],
+                        value= vae_config,
+                        label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
+                    )
+
+                    boost_choice = gr.Dropdown(
+                        choices=[
+                            # ("Auto (ON if Video longer than 5s)", 0),
+                            ("ON", 1),
+                            ("OFF", 2),
+                        ],
+                        value=boost,
+                        label="Boost: Give a 10% speedup without losing quality at the cost of a little VRAM (up to 1GB at max frames and resolution)"
+                    )
+
+                    profile_choice = gr.Dropdown(
+                        choices=[
+                            ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos on a RTX 3090 / RTX 4090", 1),
+                            ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of
VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large picture batches or long videos", 2),
+                            ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed on short videos",3),
+                            ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
+                            ("VerylowRAM_LowVRAM, profile 5 (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
+                        ],
+                        value= profile,
+                        label="Profile (for power users only, not needed to change it)"
+                    )
+
-        UI_theme_choice = gr.Dropdown(
-            choices=[
-                ("Blue Sky", "default"),
-                ("Classic Gradio", "gradio"),
-            ],
-            value=server_config.get("UI_theme_choice", "default"),
-            label="User Interface Theme. You will need to restart the App the see new Theme."
-        )
 
     msg = gr.Markdown()
@@ -4728,7 +4754,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
                    boost_choice,
                    clear_file_list_choice,
                    preload_model_policy_choice,
-                    UI_theme_choice
+                    UI_theme_choice,
+                    fit_canvas_choice
                ],
                outputs= [msg , header, model_choice]
            )
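
A minimal usage sketch of the new `calculate_new_dimensions` helper added to `wan/utils/utils.py` (the canvas and source sizes below are arbitrary example values). With `fit_into_canvas=False` (server setting `fit_canvas == 0`) the output matches the canvas pixel budget, so one side may exceed the requested dimension after rounding to multiples of 16; with `fit_into_canvas=True` the source is fitted into the canvas, and the helper also tries the canvas with its sides swapped and keeps the larger scale, so a portrait source against a landscape canvas is fitted to the rotated canvas.

```python
# Usage sketch for the helper added by this patch (example values only).
from wan.utils.utils import calculate_new_dimensions

canvas_h, canvas_w = 480, 832   # requested resolution from the UI
src_h, src_w = 720, 1280        # dimensions of the prompt image / video

# Pixel-budget mode (fit_canvas == 0): output area ~= canvas area.
print(calculate_new_dimensions(canvas_h, canvas_w, src_h, src_w, False))  # (480, 848)

# Fit mode (fit_canvas == 1): the 16:9 source fits inside the 480x832 canvas.
print(calculate_new_dimensions(canvas_h, canvas_w, src_h, src_w, True))   # (464, 832)

# Both results are multiples of 16 (spatial VAE stride x transformer patch size),
# which is why the callers in image2video.py / diffusion_forcing.py can feed them
# straight into the latent-size computation.
```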
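
The one-line change in `wan/modules/model.py` casts the VACE context to the patch-embedding weight dtype before the `Conv3d`. The snippet below is a self-contained illustration of that failure mode and of the fix, not the actual model code: it uses toy channel counts, and a float64/float32 mismatch stands in for the real float32-context vs. bf16/fp16-quantized-weights case so it runs on CPU without low-precision conv kernels.

```python
# Illustration of the dtype-mismatch fix from wan/modules/model.py
# (toy sizes; float64 vs float32 stands in for float32 vs bfloat16).
import torch

vace_patch_embedding = torch.nn.Conv3d(96, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))  # float32 weights
u = torch.randn(96, 5, 32, 32, dtype=torch.float64)  # VACE context arriving in a different dtype

try:
    vace_patch_embedding(u.unsqueeze(0))  # pre-fix behaviour: dtype mismatch raises
except RuntimeError as e:
    print("mismatch:", e)

# Patched call: align the context with the embedding weights before the convolution.
c = vace_patch_embedding(u.to(vace_patch_embedding.weight.dtype).unsqueeze(0))
print(c.shape, c.dtype)  # torch.Size([1, 128, 5, 16, 16]) torch.float32
```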