fixed vace bugs

DeepBeepMeep 2025-05-05 23:58:21 +02:00
parent cf02cc4004
commit 94d9b4aa4d
7 changed files with 258 additions and 252 deletions


@@ -15,6 +15,7 @@ from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from wan.modules.posemb_layers import get_rotary_pos_embed
from wan.utils.utils import calculate_new_dimensions
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@@ -190,6 +191,7 @@ class DTT2V:
input_video = None,
height: int = 480,
width: int = 832,
fit_into_canvas = True,
num_frames: int = 97,
num_inference_steps: int = 50,
shift: float = 1.0,
@@ -221,15 +223,16 @@
i2v_extra_kwrags = {}
prefix_video = None
predix_video_latent_length = 0
if input_video != None:
_ , _ , height, width = input_video.shape
elif image != None:
image = image[0]
frame_width, frame_height = image.size
scale = min(height / frame_height, width / frame_width)
height = (int(frame_height * scale) // 16) * 16
width = (int(frame_width * scale) // 16) * 16
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas)
image = np.array(image.resize((width, height))).transpose(2, 0, 1)
latent_length = (num_frames - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8

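For reference, a minimal sketch of the latent-shape arithmetic used just above, evaluated for an assumed 97-frame 480x832 request (the concrete numbers are illustrative, not taken from the commit):

num_frames, height, width = 97, 480, 832
latent_length = (num_frames - 1) // 4 + 1  # temporal stride 4 -> 25 latent frames
latent_height = height // 8                # spatial stride 8  -> 60
latent_width = width // 8                  # spatial stride 8  -> 104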

@@ -25,7 +25,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from wan.utils.utils import resize_lanczos
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
def optimized_scale(positive_flat, negative_flat):
@@ -120,7 +120,7 @@ class WanI2V:
img2 = None,
height =720,
width = 1280,
max_area=720 * 1280,
fit_into_canvas = True,
frame_num=81,
shift=5.0,
sample_solver='unipc',
@@ -188,22 +188,16 @@ class WanI2V:
if add_frames_for_end_image:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
h, w = img.shape[1:]
# aspect_ratio = h / w
scale1 = min(height / h, width / w)
scale2 = min(height / h, width / w)
scale = max(scale1, scale2)
new_height = int(h * scale)
new_width = int(w * scale)
h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
lat_h = round(
new_height // self.vae_stride[1] //
h // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
new_width // self.vae_stride[2] //
w // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]

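A hedged sketch of how the h, w returned by calculate_new_dimensions feed the latent sizing above; the vae_stride and patch_size values are assumed Wan defaults, not read from this diff:

vae_stride, patch_size = (4, 8, 8), (1, 2, 2)  # assumed defaults
h, w = 464, 832  # e.g. a 1080x1920 image fitted onto a 480x832 canvas
lat_h = round(h // vae_stride[1] // patch_size[1] * patch_size[1])  # 58
lat_w = round(w // vae_stride[2] // patch_size[2] * patch_size[2])  # 104
h, w = lat_h * vae_stride[1], lat_w * vae_stride[2]                 # snapped back to 464 x 832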

@@ -963,7 +963,7 @@ class WanModel(ModelMixin, ConfigMixin):
hints_list = [None ] *len(x_list)
else:
# Vace embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = c[0]

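The one-line change above casts the VACE context to the patch-embedding weight dtype before the forward pass. A self-contained sketch of the same pattern; channel sizes and dtypes are illustrative, not read from the model config:

import torch
import torch.nn as nn

vace_patch_embedding = nn.Conv3d(96, 1536, kernel_size=(1, 2, 2), stride=(1, 2, 2))  # float32 weights
u = torch.randn(96, 9, 60, 104, dtype=torch.float64)  # a context tensor arriving in a different dtype
c = vace_patch_embedding(u.to(vace_patch_embedding.weight.dtype).unsqueeze(0))  # no dtype mismatch
c = c.flatten(2).transpose(1, 2)  # (1, seq_len, dim), as in the line above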

@@ -177,15 +177,16 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None):
image_sizes = []
trim_video = len(keep_frames)
canvas_height, canvas_width = image_size
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
@@ -208,7 +209,7 @@ class WanT2V:
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
@@ -277,6 +278,7 @@ class WanT2V:
target_camera=None,
context_scale=1.0,
size=(1280, 720),
fit_into_canvas = True,
frame_num=81,
shift=5.0,
sample_solver='unipc',
@@ -430,7 +432,7 @@ class WanT2V:
kwargs.update({'cam_emb': cam_emb})
if vace:
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if overlapped_latents > 0:
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]

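The extra input_ref_images[0] != None guard above avoids a crash when the reference-image list exists but only holds None. A small, self-contained illustration (the helper name is made up for this sketch):

def count_refs(input_ref_images):
    return len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0

print(count_refs(None))            # 0
print(count_refs([None]))          # 0, the case the unguarded expression crashed on
print(count_refs([["ref1.png"]]))  # 1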

@@ -67,7 +67,17 @@ def remove_background(img, session=None):
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas):
if fit_into_canvas:
scale1 = min(canvas_height / height, canvas_width / width)
scale2 = min(canvas_width / height, canvas_height / width)
scale = max(scale1, scale2)
else:
scale = (canvas_height * canvas_width / (height * width))**(1/2)
new_height = round( height * scale / 16) * 16
new_width = round( width * scale / 16) * 16
return new_height, new_width
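A quick usage sketch of the new helper; the canvas and source sizes are only an example:

# fitting a 1080x1920 frame onto a 480x832 canvas
print(calculate_new_dimensions(480, 832, 1080, 1920, fit_into_canvas=True))   # (464, 832): stays inside the canvas
print(calculate_new_dimensions(480, 832, 1080, 1920, fit_into_canvas=False))  # (480, 848): matches the pixel budget, width exceeds 832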
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
if rm_background:


@@ -5,6 +5,7 @@ from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from .utils import calculate_new_dimensions
class VaceImageProcessor(object):
@@ -182,53 +183,22 @@ class VaceVideoProcessor(object):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0):
from wan.utils.utils import resample
target_fps = self.max_fps
# video_frames_count = len(frame_timestamps)
frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame )
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
seq_len = self.seq_len
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
# max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
max_area_z = min_area_z # workaround bug
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))
seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1)
# of = min(
# (len(frame_ids) - 1) // df + 1,
# int(seq_len / rand_area_z)
# )
of = (len(frame_ids) - 1) // df + 1
# deduce target shape of the [latent video]
# target_area_z = min(max_area_z, int(seq_len / of))
target_area_z = max_area_z
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
oh, ow = calculate_new_dimensions(canvas_height, canvas_width, h, w, fit_into_canvas)
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
else:
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
@@ -238,7 +208,7 @@ class VaceVideoProcessor(object):
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs):
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
@@ -269,7 +239,7 @@ class VaceVideoProcessor(object):
h, w = src_video.shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame )
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame )
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]

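A hedged sketch of the updated call chain: the canvas size and fit mode now travel from prepare_source down to load_video_pair / load_video_batch, which resolve the output size with calculate_new_dimensions instead of the old area-based heuristic. The vid_proc instance and the file paths below are placeholders, not values from the commit:

video, mask, *_ = vid_proc.load_video_pair(
    "guide.mp4", "mask.mp4",
    max_frames=81, trim_video=81, start_frame=0,
    canvas_height=480, canvas_width=832, fit_into_canvas=True,
)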
wgp.py

@@ -84,7 +84,6 @@ def format_time(seconds):
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
return f"{hours}h {minutes}m"
def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None:
return None
@@ -275,12 +274,12 @@ def process_prompt_and_add_tasks(state, model_choice):
video_guide = inputs["video_guide"]
video_mask = inputs["video_mask"]
if "1.3B" in model_filename :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = (" and ").join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
return
# if "1.3B" in model_filename :
# resolution_reformated = str(height) + "*" + str(width)
# if not resolution_reformated in VACE_SIZE_CONFIGS:
# res = (" and ").join(VACE_SIZE_CONFIGS.keys())
# gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
# return
if "I" in video_prompt_type:
if image_refs == None:
gr.Info("You must provide at least one Refererence Image")
@@ -1995,7 +1994,8 @@ def apply_changes( state,
boost_choice = 1,
clear_file_list = 0,
preload_model_policy_choice = 1,
UI_theme_choice = "default"
UI_theme_choice = "default",
fit_canvas_choice = 0
):
if args.lock_config:
return
@@ -2016,7 +2016,8 @@ def apply_changes( state,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
"preload_model_policy" : preload_model_policy_choice,
"UI_theme" : UI_theme_choice
"UI_theme" : UI_theme_choice,
"fit_canvas": fit_canvas_choice,
}
if Path(server_config_filename).is_file():
@@ -2050,7 +2051,7 @@ def apply_changes( state,
transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"]
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas"] for change in changes ):
model_choice = gr.Dropdown()
else:
reload_needed = True
@@ -2413,7 +2414,7 @@ def generate_video(
file_list = gen["file_list"]
prompt_no = gen["prompt_no"]
fit_canvas = server_config.get("fit_canvas", 0)
# if wan_model == None:
# gr.Info("Unable to generate a Video while a new configuration is being applied.")
# return
@@ -2555,7 +2556,7 @@ def generate_video(
source_video = None
target_camera = None
if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= fit_canvas)
target_camera = model_mode
audio_proj_split = None
@@ -2646,7 +2647,7 @@ def generate_video(
elif diffusion_forcing:
if video_source != None and len(video_source) > 0 and window_no == 1:
keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= fit_canvas, target_fps = fps)
prefix_video = prefix_video .permute(3, 0, 1, 2)
prefix_video = prefix_video .float().div_(127.5).sub_(1.) # c, f, h, w
prefix_video_frames_count = prefix_video.shape[1]
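For clarity, the prefix-video normalization shown above, applied to a dummy clip; the shapes are illustrative:

import torch

# (frames, height, width, channels) uint8 video in [0, 255]
dummy = torch.randint(0, 256, (17, 480, 832, 3), dtype=torch.uint8)
# reorder to (channels, frames, height, width) and rescale to [-1, 1]
prefix = dummy.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.)
prefix_frames_count = prefix.shape[1]  # 17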
@@ -2675,13 +2676,13 @@ def generate_video(
if preprocess_type != None :
send_cmd("progress", progress_args)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps)
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + video_length]
if window_no == 1:
image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
image_size = (height, width) # VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy ],
[image_refs_copy],
@@ -2689,10 +2690,11 @@ def generate_video(
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed,
start_frame = guide_start_frame,
pre_src_video = [pre_video_guide]
pre_src_video = [pre_video_guide],
fit_into_canvas = fit_canvas
)
if window_no == 1 and src_video != None and len(src_video) > 0:
image_size = src_video[0].shape[-2:]
# if window_no == 1 and src_video != None and len(src_video) > 0:
# image_size = src_video[0].shape[-2:]
prompts_max = gen["prompts_max"]
status = get_latest_status(state)
@@ -2722,6 +2724,7 @@ def generate_video(
# max_area=MAX_AREA_CONFIGS[resolution_reformated],
height = height,
width = width,
fit_into_canvas = fit_canvas,
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
@@ -2750,6 +2753,7 @@ def generate_video(
input_video= pre_video_guide,
height = height,
width = width,
fit_into_canvas = fit_canvas,
seed = seed,
num_frames = (video_length // 4)* 4 + 1, #377
num_inference_steps = num_inference_steps,
@@ -2777,6 +2781,7 @@ def generate_video(
target_camera= target_camera,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
fit_into_canvas = fit_canvas,
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
@@ -4042,39 +4047,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
with gr.Row():
if test_class_i2v(model_filename) and False:
resolution = gr.Dropdown(
choices=[
# 720p
("720p (same amount of pixels)", "1280x720"),
("480p (same amount of pixels)", "832x480"),
],
value=ui_defaults.get("resolution","480p"),
label="Resolution (video will have the same height / width ratio than the original image)"
)
if test_class_i2v(model_filename):
if server_config.get("fit_canvas", 0) == 1:
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
else:
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
else:
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
("832x1104 (3:4, 720p)", "832x1104"),
("1104x832 (3:4, 720p)", "1104x832"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
)
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
("832x1104 (3:4, 720p)", "832x1104"),
("1104x832 (3:4, 720p)", "1104x832"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
label= label
)
with gr.Row():
if recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
@@ -4556,156 +4557,181 @@ def generate_configuration_tab(state, blocks, header, model_choice):
with gr.Column():
model_list = []
for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
transformer_types_choices = gr.Dropdown(
choices= dropdown_choices,
value= transformer_types,
label= "Selectable Wan Transformer Models (keep empty to get All of them)",
scale= 2,
multiselect= True
)
quantization_choice = gr.Dropdown(
choices=[
("Scaled Int8 Quantization (recommended)", "int8"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
)
with gr.Tabs():
# with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Tab("General"):
for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
transformer_types_choices = gr.Dropdown(
choices= dropdown_choices,
value= transformer_types,
label= "Selectable Wan Transformer Models (keep empty to get All of them)",
scale= 2,
multiselect= True
)
mixed_precision_choice = gr.Dropdown(
choices=[
("16 bits only, requires less VRAM", "0"),
("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
],
value= server_config.get("mixed_precision", "0"),
label="Transformer Engine Calculation"
)
fit_canvas_choice = gr.Dropdown(
choices=[
("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0),
("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1),
],
value= server_config.get("fit_canvas", 0),
label="Generated Video Dimensions when Prompt contains an Image or a Video",
interactive= not lock_ui_attention
)
index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
choices=[
("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
],
value= index,
label="Text Encoder model"
)
VAE_precision_choice = gr.Dropdown(
choices=[
("16 bits, requires less VRAM and faster", "16"),
("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
],
value= server_config.get("vae_precision", "16"),
label="VAE Encoding / Decoding precision"
)
def check(mode):
if not mode in attention_modes_installed:
return " (NOT INSTALLED)"
elif not mode in attention_modes_supported:
return " (NOT SUPPORTED)"
else:
return ""
attention_choice = gr.Dropdown(
choices=[
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
("Scale Dot Product Attention: default, always available", "sdpa"),
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
],
value= attention_mode,
label="Attention Type",
interactive= not lock_ui_attention
)
save_path_choice = gr.Textbox(
label="Output Folder for Generated Videos",
value=server_config.get("save_path", save_path)
)
def check(mode):
if not mode in attention_modes_installed:
return " (NOT INSTALLED)"
elif not mode in attention_modes_supported:
return " (NOT SUPPORTED)"
else:
return ""
attention_choice = gr.Dropdown(
choices=[
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
("Scale Dot Product Attention: default, always available", "sdpa"),
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
],
value= attention_mode,
label="Attention Type",
interactive= not lock_ui_attention
)
gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation")
compile_choice = gr.Dropdown(
choices=[
("ON: works only on Linux / WSL", "transformer"),
("OFF: no other choice if you have Windows without using WSL", "" ),
],
value= compile,
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
interactive= not lock_ui_compile
)
vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
("256 x 256 : If at least 8 GB of VRAM", 2),
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
boost_choice = gr.Dropdown(
choices=[
# ("Auto (ON if Video longer than 5s)", 0),
("ON", 1),
("OFF", 2),
],
value=boost,
label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
)
profile_choice = gr.Dropdown(
choices=[
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
],
value= profile,
label="Profile (for power users only, not needed to change it)"
)
metadata_choice = gr.Dropdown(
choices=[
("Export JSON files", "json"),
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
value=server_config.get("preload_model_policy",[]),
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)
clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 5),
label="Keep Previously Generated Videos when starting a Generation Batch"
)
metadata_choice = gr.Dropdown(
choices=[
("Export JSON files", "json"),
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
value=server_config.get("preload_model_policy",[]),
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)
clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 5),
label="Keep Previously Generated Videos when starting a new Generation Batch"
)
UI_theme_choice = gr.Dropdown(
choices=[
("Blue Sky", "default"),
("Classic Gradio", "gradio"),
],
value=server_config.get("UI_theme_choice", "default"),
label="User Interface Theme. You will need to restart the App the see new Theme."
)
save_path_choice = gr.Textbox(
label="Output Folder for Generated Videos",
value=server_config.get("save_path", save_path)
)
with gr.Tab("Performance"):
quantization_choice = gr.Dropdown(
choices=[
("Scaled Int8 Quantization (recommended)", "int8"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
)
mixed_precision_choice = gr.Dropdown(
choices=[
("16 bits only, requires less VRAM", "0"),
("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
],
value= server_config.get("mixed_precision", "0"),
label="Transformer Engine Calculation"
)
index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
choices=[
("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
],
value= index,
label="Text Encoder model"
)
VAE_precision_choice = gr.Dropdown(
choices=[
("16 bits, requires less VRAM and faster", "16"),
("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
],
value= server_config.get("vae_precision", "16"),
label="VAE Encoding / Decoding precision"
)
gr.Text("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation", interactive= False, show_label= False )
compile_choice = gr.Dropdown(
choices=[
("ON: works only on Linux / WSL", "transformer"),
("OFF: no other choice if you have Windows without using WSL", "" ),
],
value= compile,
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
interactive= not lock_ui_compile
)
vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
("256 x 256 : If at least 8 GB of VRAM", 2),
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
boost_choice = gr.Dropdown(
choices=[
# ("Auto (ON if Video longer than 5s)", 0),
("ON", 1),
("OFF", 2),
],
value=boost,
label="Boost: Give a 10% speedup without losing quality at the cost of a litle VRAM (up to 1GB at max frames and resolution)"
)
profile_choice = gr.Dropdown(
choices=[
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
],
value= profile,
label="Profile (for power users only, not needed to change it)"
)
UI_theme_choice = gr.Dropdown(
choices=[
("Blue Sky", "default"),
("Classic Gradio", "gradio"),
],
value=server_config.get("UI_theme_choice", "default"),
label="User Interface Theme. You will need to restart the App the see new Theme."
)
msg = gr.Markdown()
@@ -4728,7 +4754,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
boost_choice,
clear_file_list_choice,
preload_model_policy_choice,
UI_theme_choice
UI_theme_choice,
fit_canvas_choice
],
outputs= [msg , header, model_choice]
)
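To summarize the plumbing added in wgp.py, a condensed sketch of the new setting's round trip; the names mirror the diff but this is not verbatim application code:

# stored by apply_changes() from the new fit_canvas_choice dropdown
server_config = {"fit_canvas": 1}
# read back in generate_video() and forwarded to the pipelines as fit_into_canvas
fit_canvas = server_config.get("fit_canvas", 0)
pipeline_kwargs = dict(height=480, width=832, fit_into_canvas=fit_canvas)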