Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-15 11:43:21 +00:00)

Commit 94d9b4aa4d (parent cf02cc4004): fixed vace bugs
@@ -15,6 +15,7 @@ from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from wan.modules.posemb_layers import get_rotary_pos_embed
from wan.utils.utils import calculate_new_dimensions
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

@@ -190,6 +191,7 @@ class DTT2V:
input_video = None,
height: int = 480,
width: int = 832,
fit_into_canvas = True,
num_frames: int = 97,
num_inference_steps: int = 50,
shift: float = 1.0,

@@ -221,15 +223,16 @@ class DTT2V:
i2v_extra_kwrags = {}
prefix_video = None
predix_video_latent_length = 0

if input_video != None:
_ , _ , height, width = input_video.shape
elif image != None:
image = image[0]
frame_width, frame_height = image.size
scale = min(height / frame_height, width / frame_width)
height = (int(frame_height * scale) // 16) * 16
width = (int(frame_width * scale) // 16) * 16
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas)
image = np.array(image.resize((width, height))).transpose(2, 0, 1)

latent_length = (num_frames - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8
@@ -25,7 +25,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from wan.utils.utils import resize_lanczos
from wan.utils.utils import resize_lanczos, calculate_new_dimensions

def optimized_scale(positive_flat, negative_flat):

@@ -120,7 +120,7 @@ class WanI2V:
img2 = None,
height =720,
width = 1280,
max_area=720 * 1280,
fit_into_canvas = True,
frame_num=81,
shift=5.0,
sample_solver='unipc',

@@ -188,22 +188,16 @@ class WanI2V:
if add_frames_for_end_image:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)

h, w = img.shape[1:]
# aspect_ratio = h / w

scale1 = min(height / h, width / w)
scale2 = min(height / h, width / w)
scale = max(scale1, scale2)
new_height = int(h * scale)
new_width = int(w * scale)

h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

lat_h = round(
new_height // self.vae_stride[1] //
h // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
new_width // self.vae_stride[2] //
w // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
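As an illustration of how the reworked WanI2V sizing behaves, here is a minimal sketch of the latent-size arithmetic in the hunk above; the vae_stride and patch_size values are assumptions chosen only to make the numbers concrete, not taken from this diff:

    # hypothetical stride / patch values
    vae_stride = (4, 8, 8)
    patch_size = (1, 2, 2)
    h, w = 464, 832                                                      # e.g. output of calculate_new_dimensions(...)
    lat_h = round(h // vae_stride[1] // patch_size[1] * patch_size[1])   # 464 // 8 // 2 * 2 = 58
    lat_w = round(w // vae_stride[2] // patch_size[2] * patch_size[2])   # 832 // 8 // 2 * 2 = 104
    h, w = lat_h * vae_stride[1], lat_w * vae_stride[2]                  # back to pixel space: 464, 832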
@@ -963,7 +963,7 @@ class WanModel(ModelMixin, ConfigMixin):
hints_list = [None ] *len(x_list)
else:
# Vace embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = c[0]
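The change above casts the VACE context to the dtype of vace_patch_embedding's weights before the convolution, which avoids a dtype mismatch when the conditioning tensors arrive in a different precision than the module. A minimal sketch of the pattern, with a hypothetical module and shapes chosen only for illustration:

    import torch
    import torch.nn as nn

    # hypothetical stand-in for the VACE patch embedding, running in bfloat16
    vace_patch_embedding = nn.Conv3d(16, 96, kernel_size=(1, 2, 2), stride=(1, 2, 2)).to(torch.bfloat16)
    u = torch.randn(16, 13, 60, 104)                   # conditioning tensor still in float32
    # vace_patch_embedding(u.unsqueeze(0))             # can fail with a dtype mismatch
    c = vace_patch_embedding(u.to(vace_patch_embedding.weight.dtype).unsqueeze(0))   # cast first, as the fix does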
@@ -177,15 +177,16 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None):
image_sizes = []
trim_video = len(keep_frames)
canvas_height, canvas_width = image_size

for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)

@@ -208,7 +209,7 @@ class WanT2V:
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:

@@ -277,6 +278,7 @@ class WanT2V:
target_camera=None,
context_scale=1.0,
size=(1280, 720),
fit_into_canvas = True,
frame_num=81,
shift=5.0,
sample_solver='unipc',

@@ -430,7 +432,7 @@ class WanT2V:
kwargs.update({'cam_emb': cam_emb})

if vace:
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if overlapped_latents > 0:
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
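The reworked ref_images_count guard also tolerates a list whose first entry is None. A minimal illustration of the difference (the variable value is hypothetical):

    input_ref_images = [None]      # a VACE call where no reference image was provided
    # len(input_ref_images[0])     # the old guard reaches this and raises: TypeError: object of type 'NoneType' has no len()
    ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
    print(ref_images_count)        # 0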
@@ -67,7 +67,17 @@ def remove_background(img, session=None):
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)

def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas):
if fit_into_canvas:
scale1 = min(canvas_height / height, canvas_width / width)
scale2 = min(canvas_width / height, canvas_height / width)
scale = max(scale1, scale2)
else:
scale = (canvas_height * canvas_width / (height * width))**(1/2)

new_height = round( height * scale / 16) * 16
new_width = round( width * scale / 16) * 16
return new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
if rm_background:
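Since calculate_new_dimensions is the helper the rest of this commit routes through, a short worked example of both branches (the input sizes are arbitrary):

    # canvas of 480 x 832 (height x width), source frame of 1080 x 1920
    calculate_new_dimensions(480, 832, 1080, 1920, fit_into_canvas=True)
    # scale = max(min(480/1080, 832/1920), min(832/1080, 480/1920)) ≈ 0.433  ->  (464, 832), fits inside the canvas
    calculate_new_dimensions(480, 832, 1080, 1920, fit_into_canvas=False)
    # scale = sqrt(480*832 / (1080*1920)) ≈ 0.439                            ->  (480, 848), same pixel budget, width may exceed the canvas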
@@ -5,6 +5,7 @@ from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from .utils import calculate_new_dimensions

class VaceImageProcessor(object):

@@ -182,53 +183,22 @@ class VaceVideoProcessor(object):

def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame =0):
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0):
from wan.utils.utils import resample

target_fps = self.max_fps

# video_frames_count = len(frame_timestamps)

frame_ids= resample(fps, video_frames_count, max_frames, target_fps, start_frame )

x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
seq_len = self.seq_len
# min/max area of the [latent video]
min_area_z = self.min_area / (dh * dw)
# max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
max_area_z = min_area_z # workaround bug
# sample a frame number of the [latent video]
rand_area_z = np.square(np.power(2, rng.uniform(
np.log2(np.sqrt(min_area_z)),
np.log2(np.sqrt(max_area_z))
)))

seq_len = max_area_z * ((max_frames- start_frame - 1) // df +1)

# of = min(
# (len(frame_ids) - 1) // df + 1,
# int(seq_len / rand_area_z)
# )
of = (len(frame_ids) - 1) // df + 1

# deduce target shape of the [latent video]
# target_area_z = min(max_area_z, int(seq_len / of))
target_area_z = max_area_z
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
oh, ow = calculate_new_dimensions(canvas_height, canvas_width, h, w, fit_into_canvas)

return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0):
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame)
else:
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)

@@ -238,7 +208,7 @@ class VaceVideoProcessor(object):
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, **kwargs):
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord

@@ -269,7 +239,7 @@ class VaceVideoProcessor(object):
h, w = src_video.shape[1:3]
else:
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames, start_frame = start_frame )
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame )

# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
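With the extra keyword arguments threaded through load_video_batch, a caller can now pin the output size of the Vace video processor to the UI canvas. A hypothetical call, shown only to illustrate the new parameters (paths and sizes are placeholders):

    video, mask, _, _, _ = vid_proc.load_video_pair(
        "guide.mp4", "mask.mp4",                       # hypothetical inputs
        max_frames=81, trim_video=81, start_frame=0,
        canvas_height=480, canvas_width=832,
        fit_into_canvas=True,                          # False keeps the pixel-budget sizing instead
    )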
wgp.py (415 additions, 415 deletions)
@@ -84,7 +84,6 @@ def format_time(seconds):
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
return f"{hours}h {minutes}m"

def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None:
return None
@@ -275,12 +274,12 @@ def process_prompt_and_add_tasks(state, model_choice):
video_guide = inputs["video_guide"]
video_mask = inputs["video_mask"]

if "1.3B" in model_filename :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
res = (" and ").join(VACE_SIZE_CONFIGS.keys())
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
return
# if "1.3B" in model_filename :
# resolution_reformated = str(height) + "*" + str(width)
# if not resolution_reformated in VACE_SIZE_CONFIGS:
# res = (" and ").join(VACE_SIZE_CONFIGS.keys())
# gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
# return
if "I" in video_prompt_type:
if image_refs == None:
gr.Info("You must provide at least one Refererence Image")
@@ -1995,7 +1994,8 @@ def apply_changes( state,
boost_choice = 1,
clear_file_list = 0,
preload_model_policy_choice = 1,
UI_theme_choice = "default"
UI_theme_choice = "default",
fit_canvas_choice = 0
):
if args.lock_config:
return

@@ -2016,7 +2016,8 @@ def apply_changes( state,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
"preload_model_policy" : preload_model_policy_choice,
"UI_theme" : UI_theme_choice
"UI_theme" : UI_theme_choice,
"fit_canvas": fit_canvas_choice,
}

if Path(server_config_filename).is_file():
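apply_changes now persists the new option alongside the existing ones, so the saved server configuration gains a fit_canvas key. A minimal sketch of the relevant part of that dict (other keys omitted; the 0 / 1 meanings come from the fit_canvas_choice dropdown defined further below):

    server_config = {
        # ... existing keys ...
        "UI_theme": "default",
        "fit_canvas": 0,   # 0 = requested dimensions are a pixel budget, 1 = they are a maximum width / height
    }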
@@ -2050,7 +2051,7 @@ apply_changes( state,
transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"]

if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas"] for change in changes ):
model_choice = gr.Dropdown()
else:
reload_needed = True
@@ -2413,7 +2414,7 @@ def generate_video(
file_list = gen["file_list"]
prompt_no = gen["prompt_no"]

fit_canvas = server_config.get("fit_canvas", 0)
# if wan_model == None:
# gr.Info("Unable to generate a Video while a new configuration is being applied.")
# return

@@ -2555,7 +2556,7 @@ def generate_video(
source_video = None
target_camera = None
if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= fit_canvas)
target_camera = model_mode

audio_proj_split = None

@@ -2646,7 +2647,7 @@ def generate_video(
elif diffusion_forcing:
if video_source != None and len(video_source) > 0 and window_no == 1:
keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= fit_canvas, target_fps = fps)
prefix_video = prefix_video .permute(3, 0, 1, 2)
prefix_video = prefix_video .float().div_(127.5).sub_(1.) # c, f, h, w
prefix_video_frames_count = prefix_video.shape[1]

@@ -2675,13 +2676,13 @@ def generate_video(

if preprocess_type != None :
send_cmd("progress", progress_args)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps)
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + video_length]
if window_no == 1:
image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
image_size = (height, width) # VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
[video_mask_copy ],
[image_refs_copy],

@@ -2689,10 +2690,11 @@ def generate_video(
original_video= "O" in video_prompt_type,
keep_frames=keep_frames_parsed,
start_frame = guide_start_frame,
pre_src_video = [pre_video_guide]
pre_src_video = [pre_video_guide],
fit_into_canvas = fit_canvas
)
if window_no == 1 and src_video != None and len(src_video) > 0:
image_size = src_video[0].shape[-2:]
# if window_no == 1 and src_video != None and len(src_video) > 0:
# image_size = src_video[0].shape[-2:]
prompts_max = gen["prompts_max"]
status = get_latest_status(state)

@@ -2722,6 +2724,7 @@ def generate_video(
# max_area=MAX_AREA_CONFIGS[resolution_reformated],
height = height,
width = width,
fit_into_canvas = fit_canvas,
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,

@@ -2750,6 +2753,7 @@ def generate_video(
input_video= pre_video_guide,
height = height,
width = width,
fit_into_canvas = fit_canvas,
seed = seed,
num_frames = (video_length // 4)* 4 + 1, #377
num_inference_steps = num_inference_steps,

@@ -2777,6 +2781,7 @@ def generate_video(
target_camera= target_camera,
frame_num=(video_length // 4)* 4 + 1,
size=(width, height),
fit_into_canvas = fit_canvas,
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
@@ -4042,39 +4047,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
with gr.Row():
if test_class_i2v(model_filename) and False:
resolution = gr.Dropdown(
choices=[
# 720p
("720p (same amount of pixels)", "1280x720"),
("480p (same amount of pixels)", "832x480"),
],
value=ui_defaults.get("resolution","480p"),
label="Resolution (video will have the same height / width ratio than the original image)"
)
if test_class_i2v(model_filename):
if server_config.get("fit_canvas", 0) == 1:
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
else:
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
else:
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
("832x1104 (3:4, 720p)", "832x1104"),
("1104x832 (3:4, 720p)", "1104x832"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
)
label = "Max Resolution (as it maybe less depending on video width / height ratio)"
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
("832x1104 (3:4, 720p)", "832x1104"),
("1104x832 (3:4, 720p)", "1104x832"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
label= label
)
with gr.Row():
if recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
@@ -4556,156 +4557,181 @@ def generate_configuration_tab(state, blocks, header, model_choice):
with gr.Column():
model_list = []

for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
transformer_types_choices = gr.Dropdown(
choices= dropdown_choices,
value= transformer_types,
label= "Selectable Wan Transformer Models (keep empty to get All of them)",
scale= 2,
multiselect= True
)

quantization_choice = gr.Dropdown(
choices=[
("Scaled Int8 Quantization (recommended)", "int8"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
)
with gr.Tabs():
# with gr.Row(visible=advanced_ui) as advanced_row:
with gr.Tab("General"):
for model_type in model_types:
choice = get_model_filename(model_type, transformer_quantization)
model_list.append(choice)
dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
transformer_types_choices = gr.Dropdown(
choices= dropdown_choices,
value= transformer_types,
label= "Selectable Wan Transformer Models (keep empty to get All of them)",
scale= 2,
multiselect= True
)

mixed_precision_choice = gr.Dropdown(
choices=[
("16 bits only, requires less VRAM", "0"),
("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
],
value= server_config.get("mixed_precision", "0"),
label="Transformer Engine Calculation"
)
fit_canvas_choice = gr.Dropdown(
choices=[
("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0),
("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1),
],
value= server_config.get("fit_canvas", 0),
label="Generated Video Dimensions when Prompt contains an Image or a Video",
interactive= not lock_ui_attention
)

index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
choices=[
("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
],
value= index,
label="Text Encoder model"
)

VAE_precision_choice = gr.Dropdown(
choices=[
("16 bits, requires less VRAM and faster", "16"),
("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
],
value= server_config.get("vae_precision", "16"),
label="VAE Encoding / Decoding precision"
)
def check(mode):
if not mode in attention_modes_installed:
return " (NOT INSTALLED)"
elif not mode in attention_modes_supported:
return " (NOT SUPPORTED)"
else:
return ""
attention_choice = gr.Dropdown(
choices=[
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
("Scale Dot Product Attention: default, always available", "sdpa"),
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
],
value= attention_mode,
label="Attention Type",
interactive= not lock_ui_attention
)

save_path_choice = gr.Textbox(
label="Output Folder for Generated Videos",
value=server_config.get("save_path", save_path)
)
def check(mode):
if not mode in attention_modes_installed:
return " (NOT INSTALLED)"
elif not mode in attention_modes_supported:
return " (NOT SUPPORTED)"
else:
return ""
attention_choice = gr.Dropdown(
choices=[
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
("Scale Dot Product Attention: default, always available", "sdpa"),
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
],
value= attention_mode,
label="Attention Type",
interactive= not lock_ui_attention
)
gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation")
compile_choice = gr.Dropdown(
choices=[
("ON: works only on Linux / WSL", "transformer"),
("OFF: no other choice if you have Windows without using WSL", "" ),
],
value= compile,
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
interactive= not lock_ui_compile
)
vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
("256 x 256 : If at least 8 GB of VRAM", 2),
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)
boost_choice = gr.Dropdown(
choices=[
# ("Auto (ON if Video longer than 5s)", 0),
("ON", 1),
("OFF", 2),
],
value=boost,
label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
)
profile_choice = gr.Dropdown(
choices=[
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
],
value= profile,
label="Profile (for power users only, not needed to change it)"
)

metadata_choice = gr.Dropdown(
choices=[
("Export JSON files", "json"),
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
value=server_config.get("preload_model_policy",[]),
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)

clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 5),
label="Keep Previously Generated Videos when starting a Generation Batch"
)
metadata_choice = gr.Dropdown(
choices=[
("Export JSON files", "json"),
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
value=server_config.get("preload_model_policy",[]),
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
)

clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 5),
label="Keep Previously Generated Videos when starting a new Generation Batch"
)

UI_theme_choice = gr.Dropdown(
choices=[
("Blue Sky", "default"),
("Classic Gradio", "gradio"),
],
value=server_config.get("UI_theme_choice", "default"),
label="User Interface Theme. You will need to restart the App the see new Theme."
)

save_path_choice = gr.Textbox(
label="Output Folder for Generated Videos",
value=server_config.get("save_path", save_path)
)

with gr.Tab("Performance"):

quantization_choice = gr.Dropdown(
choices=[
("Scaled Int8 Quantization (recommended)", "int8"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
)

mixed_precision_choice = gr.Dropdown(
choices=[
("16 bits only, requires less VRAM", "0"),
("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
],
value= server_config.get("mixed_precision", "0"),
label="Transformer Engine Calculation"
)

index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
choices=[
("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
],
value= index,
label="Text Encoder model"
)

VAE_precision_choice = gr.Dropdown(
choices=[
("16 bits, requires less VRAM and faster", "16"),
("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
],
value= server_config.get("vae_precision", "16"),
label="VAE Encoding / Decoding precision"
)

gr.Text("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation", interactive= False, show_label= False )
compile_choice = gr.Dropdown(
choices=[
("ON: works only on Linux / WSL", "transformer"),
("OFF: no other choice if you have Windows without using WSL", "" ),
],
value= compile,
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
interactive= not lock_ui_compile
)

vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
("256 x 256 : If at least 8 GB of VRAM", 2),
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
)

boost_choice = gr.Dropdown(
choices=[
# ("Auto (ON if Video longer than 5s)", 0),
("ON", 1),
("OFF", 2),
],
value=boost,
label="Boost: Give a 10% speedup without losing quality at the cost of a litle VRAM (up to 1GB at max frames and resolution)"
)

profile_choice = gr.Dropdown(
choices=[
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
],
value= profile,
label="Profile (for power users only, not needed to change it)"
)

UI_theme_choice = gr.Dropdown(
choices=[
("Blue Sky", "default"),
("Classic Gradio", "gradio"),
],
value=server_config.get("UI_theme_choice", "default"),
label="User Interface Theme. You will need to restart the App the see new Theme."
)

msg = gr.Markdown()
@@ -4728,7 +4754,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
boost_choice,
clear_file_list_choice,
preload_model_policy_choice,
UI_theme_choice
UI_theme_choice,
fit_canvas_choice
],
outputs= [msg , header, model_choice]
)