diff --git a/README.md b/README.md index 9513ead..4532d97 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ ## π₯ Latest News!! +* May 5 2025: π Wan 2.1GP v4.5: FantasySpeaking model, you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos (see recommended settings). New high quality processing features (mixed 16/32 bits calculation and 32 bitsVAE) * April 27 2025: π Wan 2.1GP v4.4: Phantom model support, very good model to transfer people or objects into video, works quite well at 720p and with the number of steps > 30 * April 25 2025: π Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below).Note that Skyreel uses causal attention that is only supported by Sdpa attention so even if chose an other type of attention, some of the processes will use Sdpa attention. @@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr There is also a guide that describes the various combination of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).Good luck ! -It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration +It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration. + +Other recommended setttings for Vace: +- Use a long prompt description especially for the people / objects that are in the background and not in reference images. This will ensure consistency between the windows. +- Set a medium size overlap window: long enough to give the model a sense of the motion but short enough so any overlapped blurred frames do no turn the rest of the video into a blurred video +- Truncate at least the last 4 frames of the each generated window as Vace last frames tends to be blurry + ### VACE and Sky Reels v2 Diffusion Forcing Slidig Window With this mode (that works for the moment only with Vace and Sky Reels v2) you can merge mutiple Videos to form a very long video (up to 1 min). diff --git a/fantasytalking/infer.py b/fantasytalking/infer.py new file mode 100644 index 0000000..f2d4964 --- /dev/null +++ b/fantasytalking/infer.py @@ -0,0 +1,27 @@ +# Copyright Alibaba Inc. All Rights Reserved. + +from transformers import Wav2Vec2Model, Wav2Vec2Processor + +from .model import FantasyTalkingAudioConditionModel +from .utils import get_audio_features + + +def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"): + fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device) + from mmgp import offload + from accelerate import init_empty_weights + from fantasytalking.model import AudioProjModel + with init_empty_weights(): + proj_model = AudioProjModel( 768, 2048) + offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors") + proj_model.to(device).eval().requires_grad_(False) + + wav2vec_model_dir = "ckpts/wav2vec" + wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir) + wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False) + audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames ) + + audio_proj_fea = proj_model(audio_wav2vec_fea) + pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames ) + audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768] + return audio_proj_split, audio_context_lens \ No newline at end of file diff --git a/fantasytalking/model.py b/fantasytalking/model.py new file mode 100644 index 0000000..5ec3655 --- /dev/null +++ b/fantasytalking/model.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from wan.modules.attention import pay_attention + + +class AudioProjModel(nn.Module): + def __init__(self, audio_in_dim=1024, cross_attention_dim=1024): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, audio_embeds): + context_tokens = self.proj(audio_embeds) + context_tokens = self.norm(context_tokens) + return context_tokens # [B,L,C] + +class WanCrossAttentionProcessor(nn.Module): + def __init__(self, context_dim, hidden_dim): + super().__init__() + + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.k_proj.weight) + nn.init.zeros_(self.v_proj.weight) + + def __call__( + self, + q: torch.Tensor, + audio_proj: torch.Tensor, + latents_num_frames: int = 21, + audio_context_lens = None + ) -> torch.Tensor: + """ + audio_proj: [B, 21, L3, C] + audio_context_lens: [B*21]. + """ + b, l, n, d = q.shape + + if len(audio_proj.shape) == 4: + audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d] + ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d) + ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d) + qkv_list = [audio_q, ip_key, ip_value] + del q, audio_q, ip_key, ip_value + audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens + audio_x = audio_x.view(b, l, n, d) + audio_x = audio_x.flatten(2) + elif len(audio_proj.shape) == 3: + ip_key = self.k_proj(audio_proj).view(b, -1, n, d) + ip_value = self.v_proj(audio_proj).view(b, -1, n, d) + qkv_list = [q, ip_key, ip_value] + del q, ip_key, ip_value + audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens + audio_x = audio_x.flatten(2) + return audio_x + + +class FantasyTalkingAudioConditionModel(nn.Module): + def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int): + super().__init__() + + self.audio_in_dim = audio_in_dim + self.audio_proj_dim = audio_proj_dim + + def split_audio_sequence(self, audio_proj_length, num_frames=81): + """ + Map the audio feature sequence to corresponding latent frame slices. + + Args: + audio_proj_length (int): The total length of the audio feature sequence + (e.g., 173 in audio_proj[1, 173, 768]). + num_frames (int): The number of video frames in the training data (default: 81). + + Returns: + list: A list of [start_idx, end_idx] pairs. Each pair represents the index range + (within the audio feature sequence) corresponding to a latent frame. + """ + # Average number of tokens per original video frame + tokens_per_frame = audio_proj_length / num_frames + + # Each latent frame covers 4 video frames, and we want the center + tokens_per_latent_frame = tokens_per_frame * 4 + half_tokens = int(tokens_per_latent_frame / 2) + + pos_indices = [] + for i in range(int((num_frames - 1) / 4) + 1): + if i == 0: + pos_indices.append(0) + else: + start_token = tokens_per_frame * ((i - 1) * 4 + 1) + end_token = tokens_per_frame * (i * 4 + 1) + center_token = int((start_token + end_token) / 2) - 1 + pos_indices.append(center_token) + + # Build index ranges centered around each position + pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices] + + # Adjust the first range to avoid negative start index + pos_idx_ranges[0] = [ + -(half_tokens * 2 - pos_idx_ranges[1][0]), + pos_idx_ranges[1][0], + ] + + return pos_idx_ranges + + def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0): + """ + Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding + if the range exceeds the input boundaries. + + Args: + input_tensor (Tensor): Input audio tensor of shape [1, L, 768]. + pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]]. + expand_length (int): Number of tokens to expand on both sides of each subsequence. + + Returns: + sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding. + Each element is a padded subsequence. + k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence. + Useful for ignoring padding tokens in attention masks. + """ + pos_idx_ranges = [ + [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges + ] + sub_sequences = [] + seq_len = input_tensor.size(1) # 173 + max_valid_idx = seq_len - 1 # 172 + k_lens_list = [] + for start, end in pos_idx_ranges: + # Calculate the fill amount + pad_front = max(-start, 0) + pad_back = max(end - max_valid_idx, 0) + + # Calculate the start and end indices of the valid part + valid_start = max(start, 0) + valid_end = min(end, max_valid_idx) + + # Extract the valid part + if valid_start <= valid_end: + valid_part = input_tensor[:, valid_start : valid_end + 1, :] + else: + valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2))) + + # In the sequence dimension (the 1st dimension) perform padding + padded_subseq = F.pad( + valid_part, + (0, 0, 0, pad_back + pad_front, 0, 0), + mode="constant", + value=0, + ) + k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front) + + sub_sequences.append(padded_subseq) + return torch.stack(sub_sequences, dim=1), torch.tensor( + k_lens_list, dtype=torch.long + ) diff --git a/fantasytalking/utils.py b/fantasytalking/utils.py new file mode 100644 index 0000000..e044934 --- /dev/null +++ b/fantasytalking/utils.py @@ -0,0 +1,52 @@ +# Copyright Alibaba Inc. All Rights Reserved. + +import imageio +import librosa +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def resize_image_by_longest_edge(image_path, target_size): + image = Image.open(image_path).convert("RGB") + width, height = image.size + scale = target_size / max(width, height) + new_size = (int(width * scale), int(height * scale)) + return image.resize(new_size, Image.LANCZOS) + + +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer( + save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params + ) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + + +def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames): + sr = 16000 + audio_input, sample_rate = librosa.load(audio_path, sr=sr) # ιζ ·ηδΈΊ 16kHz + + start_time = 0 + # end_time = (0 + (num_frames - 1) * 1) / fps + end_time = num_frames / fps + + start_sample = int(start_time * sr) + end_sample = int(end_time * sr) + + try: + audio_segment = audio_input[start_sample:end_sample] + except: + audio_segment = audio_input + + input_values = audio_processor( + audio_segment, sampling_rate=sample_rate, return_tensors="pt" + ).input_values.to("cuda") + + with torch.no_grad(): + fea = wav2vec(input_values).last_hidden_state + + return fea diff --git a/requirements.txt b/requirements.txt index d2fb5c8..f94ee4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ gradio==5.23.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 -mmgp==3.4.1 +mmgp==3.4.2 peft==0.14.0 mutagen pydantic==2.10.6 @@ -28,4 +28,5 @@ timm segment-anything omegaconf hydra-core +librosa # rembg==2.0.65 diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py index e3f539b..d9a98bb 100644 --- a/wan/configs/__init__.py +++ b/wan/configs/__init__.py @@ -44,6 +44,8 @@ SUPPORTED_SIZES = { VACE_SIZE_CONFIGS = { '480*832': (480, 832), '832*480': (832, 480), + '720*1280': (720, 1280), + '1280*720': (1280, 720), } VACE_MAX_AREA_CONFIGS = { diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index b8abc10..af9627b 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -56,16 +56,18 @@ class DTT2V: vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, device=self.device) - logging.info(f"Creating WanModel from {model_filename}") + logging.info(f"Creating WanModel from {model_filename[-1]}") from mmgp import offload # model_filename = "model.safetensors" - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json") + # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors" + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json") # offload.load_model_data(self.model, "recam.ckpt") # self.model.cpu() - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) + # dtype = torch.float16 + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json") - # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json") + # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json") # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json") self.model.eval().requires_grad_(False) @@ -200,6 +202,9 @@ class DTT2V: fps: int = 24, VAE_tile_size = 0, joint_pass = False, + slg_layers = None, + slg_start = 0.0, + slg_end = 1.0, callback = None, ): self._interrupt = False @@ -211,6 +216,7 @@ class DTT2V: if ar_step == 0: causal_block_size = 1 + causal_attention = False i2v_extra_kwrags = {} prefix_video = None @@ -252,31 +258,33 @@ class DTT2V: prefix_video = output_video.to(self.device) else: causal_block_size = 1 + causal_attention = False ar_step = 0 prefix_video = image prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) if prefix_video.dtype == torch.uint8: prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 prefix_video = prefix_video.to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - predix_video_latent_length = prefix_video[0].shape[1] + prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)] + predix_video_latent_length = prefix_video.shape[1] truncate_len = predix_video_latent_length % causal_block_size if truncate_len != 0: if truncate_len == predix_video_latent_length: causal_block_size = 1 + causal_attention = False + ar_step = 0 else: print("the length of prefix video is truncated for the casual block size alignment.") predix_video_latent_length -= truncate_len - prefix_video[0] = prefix_video[0][:, : predix_video_latent_length] + prefix_video = prefix_video[:, : predix_video_latent_length] base_num_frames_iter = latent_length latent_shape = [16, base_num_frames_iter, latent_height, latent_width] latents = self.prepare_latents( latent_shape, dtype=torch.float32, device=self.device, generator=generator ) - latents = [latents] if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) + latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, init_timesteps, @@ -298,6 +306,8 @@ class DTT2V: if callback != None: callback(-1, None, True, override_num_inference_steps = updated_num_steps) if self.model.enable_teacache: + x_count = 2 if self.do_classifier_free_guidance else 1 + self.model.previous_residual = [None] * x_count time_steps_comb = [] self.model.num_steps = updated_num_steps for i, timestep_i in enumerate(step_matrix): @@ -309,7 +319,7 @@ class DTT2V: self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier) del time_steps_comb from mmgp import offload - freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False) + freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False) kwrags = { "freqs" :freqs, "fps" : fps_embeds, @@ -320,27 +330,27 @@ class DTT2V: } kwrags.update(i2v_extra_kwrags) - for i, timestep_i in enumerate(tqdm(step_matrix)): + kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None + offload.set_step_no_for_lora(self.model, i) update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone() if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, valid_interval_start:predix_video_latent_length] ) * noise_factor ) timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition kwrags.update({ - "x" : torch.stack([latent_model_input[0]]), "t" : timestep, "current_step" : i, }) @@ -349,6 +359,7 @@ class DTT2V: if True: if not self.do_classifier_free_guidance: noise_pred = self.model( + x=[latent_model_input], context=[prompt_embeds], **kwrags, )[0] @@ -358,6 +369,7 @@ class DTT2V: else: if joint_pass: noise_pred_cond, noise_pred_uncond = self.model( + x=[latent_model_input, latent_model_input], context= [prompt_embeds, negative_prompt_embeds], **kwrags, ) @@ -365,12 +377,16 @@ class DTT2V: return None else: noise_pred_cond = self.model( + x=[latent_model_input], + x_id=0, context=[prompt_embeds], **kwrags, )[0] if self._interrupt: return None noise_pred_uncond = self.model( + x=[latent_model_input], + x_id=1, context=[negative_prompt_embeds], **kwrags, )[0] @@ -380,18 +396,18 @@ class DTT2V: del noise_pred_cond, noise_pred_uncond for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( + latents[:, idx] = sample_schedulers[idx].step( noise_pred[:, idx - valid_interval_start], timestep_i[idx], - latents[0][:, idx], + latents[:, idx], return_dict=False, generator=generator, )[0] sample_schedulers_counter[idx] += 1 if callback is not None: - callback(i, latents[0].squeeze(0), False) + callback(i, latents.squeeze(0), False) - x0 = latents[0].unsqueeze(0) + x0 = latents.unsqueeze(0) videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w return output_video diff --git a/wan/image2video.py b/wan/image2video.py index 9486b32..996d5b5 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -8,7 +8,7 @@ import sys import types from contextlib import contextmanager from functools import partial - +import json import numpy as np import torch import torch.cuda.amp as amp @@ -84,13 +84,29 @@ class WanI2V: config.clip_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) - logging.info(f"Creating WanModel from {model_filename}") + logging.info(f"Creating WanModel from {model_filename[-1]}") from mmgp import offload - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) - offload.change_dtype(self.model, dtype, True) - # offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True) + # fantasy = torch.load("c:/temp/fantasy.ckpt") + # proj_model = fantasy["proj_model"] + # audio_processor = fantasy["audio_processor"] + # offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors") + # offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors") + # for k,v in audio_processor.items(): + # audio_processor[k] = v.to(torch.bfloat16) + # with open("fantasy_config.json", "r", encoding="utf-8") as reader: + # config_text = reader.read() + # config_json = json.loads(config_text) + # offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json) + # model_filename = [model_filename, "audio_processor_bf16.safetensors"] + # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors" + # dtype = torch.float16 + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json") + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) + # offload.change_dtype(self.model, dtype, True) + # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json") + # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json") + # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json") # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") self.model.eval().requires_grad_(False) @@ -102,6 +118,8 @@ class WanI2V: input_prompt, img, img2 = None, + height =720, + width = 1280, max_area=720 * 1280, frame_num=81, shift=5.0, @@ -119,7 +137,11 @@ class WanI2V: slg_end = 1.0, cfg_star_switch = True, cfg_zero_step = 5, - add_frames_for_end_image = True + add_frames_for_end_image = True, + audio_scale=None, + audio_cfg_scale=None, + audio_proj=None, + audio_context_lens=None, ): r""" Generates video frames from input image and text prompt using diffusion process. @@ -167,17 +189,25 @@ class WanI2V: frame_num +=1 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + h, w = img.shape[1:] - aspect_ratio = h / w + # aspect_ratio = h / w + + scale1 = min(height / h, width / w) + scale2 = min(height / h, width / w) + scale = max(scale1, scale2) + new_height = int(h * scale) + new_width = int(w * scale) + lat_h = round( - np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + new_height // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) lat_w = round( - np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + new_width // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) h = lat_h * self.vae_stride[1] w = lat_w * self.vae_stride[2] - + clip_image_size = self.clip.model.image_size img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype img = resize_lanczos(img, clip_image_size, clip_image_size) @@ -271,98 +301,101 @@ class WanI2V: # sample videos latent = noise - batch_size = latent.shape[0] + batch_size = 1 freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx) - arg_c = { - 'context': [context], - 'clip_fea': clip_context, - 'y': [y], - 'freqs' : freqs, - 'pipeline' : self, - 'callback' : callback - } + kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback } - arg_null = { - 'context': [context_null], - 'clip_fea': clip_context, - 'y': [y], - 'freqs' : freqs, - 'pipeline' : self, - 'callback' : callback - } - - arg_both= { - 'context': [context, context_null], - 'clip_fea': clip_context, - 'y': [y], - 'freqs' : freqs, - 'pipeline' : self, - 'callback' : callback - } + if audio_proj != None: + kwargs.update({ + "audio_proj": audio_proj.to(self.dtype), + "audio_context_lens": audio_context_lens, + }) if self.model.enable_teacache: + self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2) self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) # self.model.to(self.device) if callback != None: callback(-1, None, True) - + latent = latent.to(self.device) for i, t in enumerate(tqdm(timesteps)): offload.set_step_no_for_lora(self.model, i) - slg_layers_local = None - if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): - slg_layers_local = slg_layers - - latent_model_input = [latent.to(self.device)] + kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None + latent_model_input = latent timestep = [t] timestep = torch.stack(timestep).to(self.device) + kwargs.update({ + 't' :timestep, + 'current_step' :i, + }) + + if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) + if audio_proj == None: + noise_pred_cond, noise_pred_uncond = self.model( + [latent_model_input, latent_model_input], + context=[context, context_null], + **kwargs) + else: + noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model( + [latent_model_input, latent_model_input, latent_model_input], + context=[context, context, context_null], + audio_scale = [audio_scale, None, None ], + **kwargs) + if self._interrupt: return None else: noise_pred_cond = self.model( - latent_model_input, - t=timestep, - current_step=i, - is_uncond=False, - **arg_c, + [latent_model_input], + context=[context], + audio_scale = None if audio_scale == None else [audio_scale], + x_id=0, + **kwargs, )[0] if self._interrupt: - return None + return None + + if audio_proj != None: + noise_pred_noaudio = self.model( + [latent_model_input], + x_id=1, + context=[context], + **kwargs, + )[0] + if self._interrupt: + return None + noise_pred_uncond = self.model( - latent_model_input, - t=timestep, - current_step=i, - is_uncond=True, - slg_layers=slg_layers_local, - **arg_null, + [latent_model_input], + x_id=1 if audio_scale == None else 2, + context=[context_null], + **kwargs, )[0] if self._interrupt: return None del latent_model_input # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ - noise_pred_text = noise_pred_cond if cfg_star_switch: - positive_flat = noise_pred_text.view(batch_size, -1) + positive_flat = noise_pred_cond.view(batch_size, -1) negative_flat = noise_pred_uncond.view(batch_size, -1) alpha = optimized_scale(positive_flat,negative_flat) alpha = alpha.view(batch_size, 1, 1, 1) - if (i <= cfg_zero_step): - noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... + noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred... else: noise_pred_uncond *= alpha - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) - - del noise_pred_uncond - + if audio_scale == None: + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio) + noise_pred_uncond, noise_pred_noaudio = None, None temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, @@ -376,9 +409,6 @@ class WanI2V: if callback is not None: callback(i, latent, False) - - # x0 = [latent.to(self.device, dtype=self.dtype)] - x0 = [latent] # x0 = [lat_y] diff --git a/wan/modules/attention.py b/wan/modules/attention.py index e70c098..868edbe 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -57,15 +57,15 @@ def sageattn_wrapper( ): q,k, v = qkv_list padding_length = q.shape[0] -attention_length - q = q[:attention_length, :, : ].unsqueeze(0) - k = k[:attention_length, :, : ].unsqueeze(0) - v = v[:attention_length, :, : ].unsqueeze(0) + q = q[:attention_length, :, : ] + k = k[:attention_length, :, : ] + v = v[:attention_length, :, : ] if True: qkv_list = [q,k,v] del q, k ,v - o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0) + o = alt_sageattn(qkv_list, tensor_layout="NHD") else: - o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0) + o = sageattn(q, k, v, tensor_layout="NHD") del q, k ,v qkv_list.clear() @@ -107,14 +107,14 @@ def sdpa_wrapper( attention_length ): q,k, v = qkv_list - padding_length = q.shape[0] -attention_length - q = q[:attention_length, :].transpose(0,1).unsqueeze(0) - k = k[:attention_length, :].transpose(0,1).unsqueeze(0) - v = v[:attention_length, :].transpose(0,1).unsqueeze(0) + padding_length = q.shape[1] -attention_length + q = q[:attention_length, :].transpose(1,2) + k = k[:attention_length, :].transpose(1,2) + v = v[:attention_length, :].transpose(1,2) o = F.scaled_dot_product_attention( q, k, v, attn_mask=None, is_causal=False - ).squeeze(0).transpose(0,1) + ).transpose(1,2) del q, k ,v qkv_list.clear() @@ -159,36 +159,72 @@ def pay_attention( deterministic=False, version=None, force_attention= None, - cross_attn= False + cross_attn= False, + k_lens = None ): attn = offload.shared_state["_attention"] if force_attention== None else force_attention q,k,v = qkv_list qkv_list.clear() - # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype - assert b==1 - q = q.squeeze(0) - k = k.squeeze(0) - v = v.squeeze(0) - q = q.to(v.dtype) k = k.to(v.dtype) - - # if q_scale is not None: - # q = q * q_scale - + if b > 0 and k_lens != None and attn in ("sage2", "sdpa"): + # Poor's man var len attention + chunk_sizes = [] + k_sizes = [] + current_size = k_lens[0] + current_count= 1 + for k_len in k_lens[1:]: + if k_len == current_size: + current_count += 1 + else: + chunk_sizes.append(current_count) + k_sizes.append(current_size) + current_count = 1 + current_size = k_len + chunk_sizes.append(current_count) + k_sizes.append(k_len) + if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]: + q_chunks =torch.split(q, chunk_sizes) + k_chunks =torch.split(k, chunk_sizes) + v_chunks =torch.split(v, chunk_sizes) + q, k, v = None, None, None + k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)] + v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)] + o = [] + for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks): + qkv_list = [sub_q, sub_k, sub_v] + sub_q, sub_k, sub_v = None, None, None + o.append( pay_attention(qkv_list) ) + q_chunks, k_chunks, v_chunks = None, None, None + o = torch.cat(o, dim = 0) + return o if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: warnings.warn( 'Flash attention 3 is not available, use flash attention 2 instead.' ) if attn=="sage" or attn=="flash": - cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda") - cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda") + if b != 1 : + if k_lens == None: + k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True) + k = torch.cat([u[:v] for u, v in zip(k, k_lens)]) + v = torch.cat([u[:v] for u, v in zip(v, k_lens)]) + q = q.reshape(-1, *q.shape[-2:]) + q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True) + cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) + else: + cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda") + cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda") + q = q.squeeze(0) + k = k.squeeze(0) + v = v.squeeze(0) + # apply attention if attn=="sage": @@ -207,7 +243,7 @@ def pay_attention( qkv_list = [q,k,v] del q,k,v - x = sageattn_wrapper(qkv_list, lq).unsqueeze(0) + x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0) # else: # layer = offload.shared_state["layer"] # embed_sizes = offload.shared_state["embed_sizes"] @@ -267,8 +303,8 @@ def pay_attention( elif attn=="sdpa": qkv_list = [q, k, v] - del q, k , v - x = sdpa_wrapper( qkv_list, lq).unsqueeze(0) + del q ,k ,v + x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0) elif attn=="flash" and version == 3: # Note: dropout_p, window_size are not supported in FA3 now. x = flash_attn_interface.flash_attn_varlen_func( @@ -302,59 +338,11 @@ def pay_attention( # output elif attn=="xformers": - x = memory_efficient_attention( - q.unsqueeze(0), - k.unsqueeze(0), - v.unsqueeze(0), - ) #.unsqueeze(0) + from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask + if b != 1 and k_lens != None: + attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) ) + x = memory_efficient_attention(q, k, v, attn_bias= attn_mask ) + else: + x = memory_efficient_attention(q, k, v ) - return x.type(out_dtype) - - -def attention( - q, - k, - v, - q_lens=None, - k_lens=None, - dropout_p=0., - softmax_scale=None, - q_scale=None, - causal=False, - window_size=(-1, -1), - deterministic=False, - dtype=torch.bfloat16, - fa_version=None, -): - if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: - return pay_attention( - q=q, - k=k, - v=v, - q_lens=q_lens, - k_lens=k_lens, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - q_scale=q_scale, - causal=causal, - window_size=window_size, - deterministic=deterministic, - dtype=dtype, - version=fa_version, - ) - else: - if q_lens is not None or k_lens is not None: - warnings.warn( - 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' - ) - attn_mask = None - - q = q.transpose(1, 2).to(dtype) - k = k.transpose(1, 2).to(dtype) - v = v.transpose(1, 2).to(dtype) - - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) - - out = out.transpose(1, 2).contiguous() - return out + return x.type(out_dtype) \ No newline at end of file diff --git a/wan/modules/model.py b/wan/modules/model.py index 3e65a01..1fb03e1 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module): del q,k q,k = apply_rotary_emb(qklist, freqs, head_first=False) - qkv_list = [q,k,v] - del q,k,v if block_mask == None: + qkv_list = [q,k,v] + del q,k,v x = pay_attention( qkv_list, window_size=self.window_size) @@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module): .transpose(1, 2) .contiguous() ) + del q,k,v # if not self._flag_ar_attention: # q = rope_apply(q, grid_sizes, freqs) @@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, xlist, context): + def forward(self, xlist, context, grid_sizes, *args, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention): v = self.v(context).view(b, -1, n, d) # compute attention + v = v.contiguous().clone() qvl_list=[q, k, v] del q, k, v x = pay_attention(qvl_list, cross_attn= True) @@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, context): + def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ): r""" Args: x(Tensor): Shape [B, L1, C] @@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention): del x self.norm_q(q) q= q.view(b, -1, n, d) + if audio_scale != None: + audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens) k = self.k(context) self.norm_k(k) k = k.view(b, -1, n, d) @@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention): img_x = img_x.flatten(2) x += img_x del img_x + if audio_scale != None: + x.add_(audio_x, alpha= audio_scale) x = self.o(x) return x @@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module): hints= None, context_scale=1.0, cam_emb= None, - block_mask = None + block_mask = None, + audio_proj= None, + audio_context_lens= None, + audio_scale=None, ): r""" Args: @@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module): if cam_emb != None: cam_emb = self.cam_encoder(cam_emb) cam_emb = cam_emb.repeat(1, 2, 1) - cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1) + cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1) cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d') x_mod += cam_emb @@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module): y = y.to(attention_dtype) ylist= [y] del y - x += self.cross_attn(ylist, context).to(dtype) + x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) y = self.norm2(x) @@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin): eps=1e-6, recammaster = False, inject_sample_info = False, + fantasytalking_dim = 0, ): r""" Initialize the diffusion model backbone. @@ -742,43 +752,48 @@ class WanModel(ModelMixin, ConfigMixin): block.projector.weight = nn.Parameter(torch.eye(dim)) block.projector.bias = nn.Parameter(torch.zeros(dim)) + if fantasytalking_dim > 0: + from fantasytalking.model import WanCrossAttentionProcessor + for block in self.blocks: + block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim) + + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): + layer_list = [self.head, self.head.head, self.patch_embedding] + target_dype= dtype + + layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2], + self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ] + + for block in self.blocks: + layer_list2 += [block.norm3] - def lock_layers_dtypes(self, dtype = torch.float32, force = False): - count = 0 - layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2], - self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ] if hasattr(self, "fps_embedding"): - layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]] + layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]] if hasattr(self, "vace_patch_embedding"): - layer_list += [self.vace_patch_embedding] - layer_list += [self.vace_blocks[0].before_proj] + layer_list2 += [self.vace_patch_embedding] + layer_list2 += [self.vace_blocks[0].before_proj] for block in self.vace_blocks: - layer_list += [block.after_proj, block.norm3] + layer_list2 += [block.after_proj, block.norm3] + + target_dype2 = hybrid_dtype if hybrid_dtype != None else dtype # cam master if hasattr(self.blocks[0], "projector"): for block in self.blocks: - layer_list += [block.projector] + layer_list2 += [block.projector] - for block in self.blocks: - layer_list += [block.norm3] - for layer in layer_list: - if hasattr(layer, "weight"): - if layer.weight.dtype == dtype : - count += 1 - elif force: - if hasattr(layer, "weight"): - layer.weight.data = layer.weight.data.to(dtype) + for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dype, target_dype2]): + for layer in current_layer_list: + layer._lock_dtype = dtype + + if hasattr(layer, "weight") and layer.weight.dtype != current_dtype : + layer.weight.data = layer.weight.data.to(current_dtype) if hasattr(layer, "bias"): - layer.bias.data = layer.bias.data.to(dtype) - count += 1 + layer.bias.data = layer.bias.data.to(current_dtype) - layer._lock_dtype = dtype - - - if count > 0: - self._lock_dtype = dtype + self._lock_dtype = dtype def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): @@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin): t = torch.stack([t]) time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim e_list.append(time_emb) - + best_deltas = None best_threshold = 0.01 best_diff = 1000 best_signed_diff = 1000 @@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin): accumulated_rel_l1_distance =0 nb_steps = 0 diff = 1000 + deltas = [] for i, t in enumerate(timesteps): skip = False - if not (i<=start_step or i== len(timesteps)): - accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item())) + if not (i<=start_step or i== len(timesteps)-1): + delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item())) + # deltas.append(delta) + accumulated_rel_l1_distance += delta if accumulated_rel_l1_distance < threshold: skip = True + # deltas.append("SKIP") else: accumulated_rel_l1_distance = 0 if not skip: @@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin): diff = abs(signed_diff) if diff < best_diff: best_threshold = threshold + best_deltas = deltas best_diff = diff best_signed_diff = signed_diff elif diff > best_diff: @@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin): threshold += 0.01 self.rel_l1_thresh = best_threshold print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") + # print(f"deltas:{best_deltas}") return best_threshold @@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin): freqs = None, pipeline = None, current_step = 0, - is_uncond=False, + x_id= 0, max_steps = 0, slg_layers=None, callback = None, @@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin): fps = None, causal_block_size = 1, causal_attention = False, - x_neg = None + audio_proj=None, + audio_context_lens=None, + audio_scale=None, + ): - # dtype = self.blocks[0].self_attn.q.weight.dtype - dtype = self.patch_embedding.weight.dtype + # patch_dtype = self.patch_embedding.weight.dtype + modulation_dtype = self.time_projection[1].weight.dtype if self.model_type == 'i2v': assert clip_fea is not None and y is not None @@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin): if torch.is_tensor(freqs) and freqs.device != device: freqs = freqs.to(device) - if y is not None: - x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] - # embeddings - x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x] - if x_neg !=None: - x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg] + x_list = x + joint_pass = len(x_list) > 1 + is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ] + last_x_idx = 0 + for i, (is_source, x) in enumerate(zip(is_source_x, x_list)): + if is_source: + x_list[i] = x_list[0].clone() + last_x_idx = i + else: + # image source + if y is not None: + x = torch.cat([x, y], dim=0) + # embeddings + x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + x_list[i] = x + x, y = None, None - grid_sizes = [ list(u.shape[2:]) for u in x] - embed_sizes = grid_sizes[0] - if causal_attention : #causal_block_size > 0: - frame_num = embed_sizes[0] - height = embed_sizes[1] - width = embed_sizes[2] + + block_mask = None + if causal_attention and causal_block_size > 0 and False: # NEVER WORKED + frame_num = grid_sizes[0] + height = grid_sizes[1] + width = grid_sizes[2] block_num = frame_num // causal_block_size range_tensor = torch.arange(block_num).view(-1, 1) range_tensor = range_tensor.repeat(1, causal_block_size).flatten() @@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin): block_mask = causal_mask.unsqueeze(0).unsqueeze(0) del causal_mask - offload.shared_state["embed_sizes"] = embed_sizes + offload.shared_state["embed_sizes"] = grid_sizes offload.shared_state["step_no"] = current_step offload.shared_state["max_steps"] = max_steps - x = [u.flatten(2).transpose(1, 2) for u in x] - x = x[0] - if x_neg !=None: - x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg] - x_neg = x_neg[0] + _flag_df = t.dim() == 2 - if t.dim() == 2: - b, f = t.shape - _flag_df = True - else: - _flag_df = False e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype) + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype) ) # b, dim e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) if self.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=device) - fps_emb = self.fps_embedding(fps).to(dtype) # float() + fps_emb = self.fps_embedding(fps).to(e.dtype) if _flag_df: e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) else: @@ -913,30 +940,28 @@ class WanModel(ModelMixin, ConfigMixin): if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ] - - joint_pass = len(context) > 0 - x_list = [x] - if joint_pass: - if x_neg == None: - x_list += [x.clone() for i in range(len(context) - 1) ] - else: - x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg] - is_uncond = False - del x + context_list = context + if audio_scale != None: + audio_scale_list = audio_scale + else: + audio_scale_list = [None] * len(x_list) # arguments kwargs = dict( grid_sizes=grid_sizes, freqs=freqs, - cam_emb = cam_emb + cam_emb = cam_emb, + block_mask = block_mask, + audio_proj=audio_proj, + audio_context_lens=audio_context_lens, ) if vace_context == None: hints_list = [None ] *len(x_list) else: - # embeddings + # Vace embeddings c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] c = c[0] @@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin): should_calc = True if self.enable_teacache: - if is_uncond: + if x_id != 0: should_calc = self.should_calc else: if current_step <= self.teacache_start_step or current_step == self.num_steps-1: @@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin): self.accumulated_rel_l1_distance = 0 else: rescale_func = np.poly1d(self.coefficients) - self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())) + delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())) + self.accumulated_rel_l1_distance += delta if self.accumulated_rel_l1_distance < self.rel_l1_thresh: should_calc = False self.teacache_skipped_steps += 1 - # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" ) + # print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" ) else: should_calc = True self.accumulated_rel_l1_distance = 0 @@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin): self.should_calc = should_calc if not should_calc: - for i, x in enumerate(x_list): - x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond + if joint_pass: + for i, x in enumerate(x_list): + x += self.previous_residual[i] + else: + x = x_list[0] + x += self.previous_residual[x_id] + x = None else: if self.enable_teacache: - if joint_pass or is_uncond: - self.previous_residual_uncond = None - if joint_pass or not is_uncond: - self.previous_residual_cond = None - ori_hidden_states = x_list[0].clone() + if joint_pass: + self.previous_residual = [ None ] * len(self.previous_residual) + else: + self.previous_residual[x_id] = None + ori_hidden_states = [ None ] * len(x_list) + ori_hidden_states[0] = x_list[0].clone() + for i in range(1, len(x_list)): + ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone() for block_idx, block in enumerate(self.blocks): offload.shared_state["layer"] = block_idx @@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin): if pipeline._interrupt: return [None] * len(x_list) - if slg_layers is not None and block_idx in slg_layers: - if is_uncond and not joint_pass: + if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers: + if not joint_pass: continue x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs) - else: - for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)): - x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs) + for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)): + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs) del x del context, hints if self.enable_teacache: if joint_pass: - self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states) - self.previous_residual_uncond = ori_hidden_states - torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond) + for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) : + if i == 0 or is_source and i != last_x_idx : + self.previous_residual[i] = torch.sub(x, ori) + else: + self.previous_residual[i] = ori + torch.sub(x, ori, out=self.previous_residual[i]) + ori_hidden_states[i] = None + x , ori = None, None else: - residual = ori_hidden_states # just to have a readable code - torch.sub(x_list[0], ori_hidden_states, out=residual) - if i==1 or is_uncond: - self.previous_residual_uncond = residual - else: - self.previous_residual_cond = residual + residual = ori_hidden_states[0] # just to have a readable code + torch.sub(x_list[0], ori_hidden_states[0], out=residual) + self.previous_residual[x_id] = residual residual, ori_hidden_states = None, None for i, x in enumerate(x_list): @@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin): c = self.out_dim out = [] - for u, v in zip(x, grid_sizes): - u = u[:math.prod(v)].view(*v, *self.patch_size, c) + for u in x: + u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c) u = torch.einsum('fhwpqrc->cfphqwr', u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) out.append(u) return out diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py index e023a28..646c3c9 100644 --- a/wan/modules/sage2_core.py +++ b/wan/modules/sage2_core.py @@ -140,7 +140,7 @@ def sageattn( elif arch == "sm90": return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32") elif arch == "sm120": - return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. else: raise ValueError(f"Unsupported CUDA architecture: {arch}") diff --git a/wan/text2video.py b/wan/text2video.py index 4f89787..e85aaca 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -78,15 +78,16 @@ class WanT2V: vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, device=self.device) - logging.info(f"Creating WanModel from {model_filename}") + logging.info(f"Creating WanModel from {model_filename[-1]}") from mmgp import offload # model_filename + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json") # offload.load_model_data(self.model, "e:/vace.safetensors") # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") # self.model.to(torch.bfloat16) # self.model.cpu() - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json") # offload.save_model(self.model, "phantom_1.3B.safetensors") @@ -95,7 +96,7 @@ class WanT2V: self.sample_neg_prompt = config.sample_neg_prompt - if "Vace" in model_filename: + if "Vace" in model_filename[-1]: self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), min_area=480*832, max_area=480*832, @@ -107,7 +108,7 @@ class WanT2V: self.adapt_vace_model() - def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): + def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0): if ref_images is None: ref_images = [None] * len(frames) else: @@ -119,6 +120,11 @@ class WanT2V: inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] inactive = self.vae.encode(inactive, tile_size = tile_size) + # inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive] + # if overlapped_latents > 0: + # for t in inactive: + # t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor + reactive = self.vae.encode(reactive, tile_size = tile_size) latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] @@ -288,7 +294,10 @@ class WanT2V: slg_end = 1.0, cfg_star_switch = True, cfg_zero_step = 5, - ): + overlapped_latents = 0, + overlap_noise = 0, + vace = False + ): r""" Generates video frames from text prompt using diffusion process. @@ -343,20 +352,20 @@ class WanT2V: size = (source_video.shape[2], source_video.shape[1]) source_video = source_video.to(dtype=self.dtype , device=self.device) source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) - source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) + source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device) del source_video # Process target camera (recammaster) from wan.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) - if input_frames != None: + if vace : # vace context encode input_frames = [u.to(self.device) for u in input_frames] input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] input_masks = [u.to(self.device) for u in input_masks] - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise ) m0 = self.vace_encode_masks(input_masks, input_ref_images) z = self.vace_latent(z0, m0) @@ -365,10 +374,10 @@ class WanT2V: else: if input_ref_images != None: # Phantom Ref images phantom = True - input_ref_images = [self.get_vae_latents(input_ref_images, self.device)] - input_ref_images_neg = [torch.zeros_like(input_ref_images[0])] + input_ref_images = self.get_vae_latents(input_ref_images, self.device) + input_ref_images_neg = torch.zeros_like(input_ref_images) F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0), + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0), size[1] // self.vae_stride[1], size[0] // self.vae_stride[2]) @@ -405,37 +414,48 @@ class WanT2V: raise NotImplementedError("Unsupported solver.") # sample videos - latents = noise + latents = noise[0] del noise - batch_size =len(latents) + batch_size = 1 if target_camera != None: - shape = list(latents[0].shape[1:]) + shape = list(latents.shape[1:]) shape[0] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) + freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx) kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback} if target_camera != None: kwargs.update({'cam_emb': cam_emb}) - if input_frames != None: + if vace: + ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0 kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale}) + if overlapped_latents > 0: + z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z] if self.model.enable_teacache: + x_count = 3 if phantom else 2 + self.model.previous_residual = [None] * x_count self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) if callback != None: callback(-1, None, True) for i, t in enumerate(tqdm(timesteps)): + if vace and overlapped_latents > 0 : + # noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000 + noise_factor = overlap_noise / 1000 # * (999-t) / 999 + # noise_factor = overlap_noise / 1000 # * t / 999 + for zz, zz_r in zip(z, z_reactive): + zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor + if target_camera != None: - latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] + latent_model_input = torch.cat([latents, source_latents], dim=1) else: latent_model_input = latents - slg_layers_local = None - if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): - slg_layers_local = slg_layers + kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None + timestep = [t] offload.set_step_no_for_lora(self.model, i) timestep = torch.stack(timestep) @@ -444,38 +464,38 @@ class WanT2V: if joint_pass: if phantom: pos_it, pos_i, neg = self.model( - [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], - x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], + [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 + + [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)], context = [context, context_null, context_null], **kwargs) else: noise_pred_cond, noise_pred_uncond = self.model( - latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs) + [latent_model_input, latent_model_input], context = [context, context_null], **kwargs) if self._interrupt: return None else: if phantom: pos_it = self.model( - [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs + [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs )[0] if self._interrupt: return None pos_i = self.model( - [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs + [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs )[0] if self._interrupt: return None neg = self.model( - [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs + [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs )[0] if self._interrupt: return None else: noise_pred_cond = self.model( - latent_model_input, is_uncond = False, context = [context], **kwargs)[0] + [latent_model_input], x_id = 0, context = [context], **kwargs)[0] if self._interrupt: return None noise_pred_uncond = self.model( - latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0] + [latent_model_input], x_id = 1, context = [context_null], **kwargs)[0] if self._interrupt: return None @@ -505,21 +525,21 @@ class WanT2V: temp_x0 = sample_scheduler.step( noise_pred[:, :target_shape[1]].unsqueeze(0), t, - latents[0].unsqueeze(0), + latents.unsqueeze(0), return_dict=False, generator=seed_g)[0] - latents = [temp_x0.squeeze(0)] + latents = temp_x0.squeeze(0) del temp_x0 if callback is not None: - callback(i, latents[0], False) + callback(i, latents, False) - x0 = latents + x0 = [latents] if input_frames == None: if phantom: # phantom post processing - x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0] + x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0] videos = self.vae.decode(x0, VAE_tile_size) else: # vace post processing diff --git a/wgp.py b/wgp.py index ee4ab15..1519b17 100644 --- a/wgp.py +++ b/wgp.py @@ -40,7 +40,7 @@ global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.4.1" +target_mmgp_version = "3.4.2" from importlib.metadata import version mmgp_version = version("mmgp") if mmgp_version != target_mmgp_version: @@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version: lock = threading.Lock() current_task_id = None task_id = 0 -# progress_tracker = {} -# tracker_lock = threading.Lock() -# def download_ffmpeg(): -# if os.name != 'nt': return -# exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe'] -# if all(os.path.exists(e) for e in exes): return -# api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest' -# r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'}) -# assets = r.json().get('assets', []) -# zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None) -# if not zip_asset: return -# zip_url = zip_asset['browser_download_url'] -# zip_name = zip_asset['name'] -# with requests.get(zip_url, stream=True) as resp: -# total = int(resp.headers.get('Content-Length', 0)) -# with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar: -# for chunk in resp.iter_content(chunk_size=8192): -# f.write(chunk) -# pbar.update(len(chunk)) -# with zipfile.ZipFile(zip_name) as z: -# for f in z.namelist(): -# if f.endswith(tuple(exes)) and '/bin/' in f: -# z.extract(f) -# os.rename(f, os.path.basename(f)) -# os.remove(zip_name) +def download_ffmpeg(): + if os.name != 'nt': return + exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe'] + if all(os.path.exists(e) for e in exes): return + api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest' + r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'}) + assets = r.json().get('assets', []) + zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None) + if not zip_asset: return + zip_url = zip_asset['browser_download_url'] + zip_name = zip_asset['name'] + with requests.get(zip_url, stream=True) as resp: + total = int(resp.headers.get('Content-Length', 0)) + with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + pbar.update(len(chunk)) + with zipfile.ZipFile(zip_name) as z: + for f in z.namelist(): + if f.endswith(tuple(exes)) and '/bin/' in f: + z.extract(f) + os.rename(f, os.path.basename(f)) + os.remove(zip_name) def format_time(seconds): if seconds < 60: @@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice): resolution = inputs["resolution"] width, height = resolution.split("x") width, height = int(width), int(height) - if test_class_i2v(model_filename): - if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: - gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") - return - resolution = str(width) + "*" + str(height) - if resolution not in ['720*1280', '1280*720', '480*832', '832*480']: - gr.Info(f"Resolution {resolution} not supported by image 2 video") - return + # if test_class_i2v(model_filename): + # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: + # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") + # return + # resolution = str(width) + "*" + str(height) + # if resolution not in ['720*1280', '1280*720', '480*832', '832*480']: + # gr.Info(f"Resolution {resolution} not supported by image 2 video") + # return if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ): gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") @@ -533,7 +531,7 @@ def save_queue_action(state): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] for key in image_keys: images_pil = params_copy.get(key) @@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): params['state'] = state image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] loaded_pil_images = {} loaded_video_paths = {} @@ -925,7 +923,7 @@ def autosave_queue(): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] for key in image_keys: images_pil = params_copy.get(key) @@ -1418,32 +1416,35 @@ else: text = reader.read() server_config = json.loads(text) -# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ): -# if Path(src_path).is_file(): -# shutil.move(src_path, tgt_path) ) -# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]: -# if Path(path).is_file(): -# os.remove(path) +# Deprecated models +for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", +"sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", +"wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors" +]: + if Path(os.path.join("ckpts" , path)).is_file(): + os.remove( os.path.join("ckpts" , path)) -path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors" -if os.path.isfile(path) and os.path.getsize(path) > 4000000000: - os.remove(path) transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", - "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", + "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"] -transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", - "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", - "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"] +transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors", + "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", + "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", + "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"] transformer_choices = transformer_choices_t2v + transformer_choices_i2v - -model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"] +def get_dependent_models(model_filename, quantization ): + if "fantasy" in model_filename: + return [get_model_filename("i2v_720p", quantization)] + else: + return [] +model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", - "phantom_1.3B" : "phantom_1.3B", } + "phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" } def get_model_type(model_filename): @@ -1453,7 +1454,7 @@ def get_model_type(model_filename): raise Exception("Unknown model:" + model_filename) def test_class_i2v(model_filename): - return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename + return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename def get_model_name(model_filename, description_container = [""]): if "Fun" in model_filename: @@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]): model_name = "Wan2.1 Phantom" model_name += " 14B" if "14B" in model_filename else " 1.3B" description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p." + elif "fantasy" in model_filename: + model_name = "Wan2.1 Fantasy Speaking 720p" + model_name += " 14B" if "14B" in model_filename else " 1.3B" + description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input." else: model_name = "Wan2.1 text2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" @@ -1536,6 +1541,7 @@ def get_default_settings(filename): "repeat_generation": 1, "multi_images_gen_type": 0, "guidance_scale": 5.0, + "audio_guidance_scale": 5.0, "flow_shift": get_default_flow(filename, i2v), "negative_prompt": "", "activated_loras": [], @@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename): from huggingface_hub import hf_hub_download, snapshot_download repoId = "DeepBeepMeep/Wan2.1" - sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ] - fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ] + sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ] + fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], + ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ] targetRoot = "ckpts/" for sourceFolder, files in zip(sourceFolderList,fileList ): if len(files)==0: @@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset, return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset -def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): +def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): cfg = WAN_CONFIGS['t2v-14B'] + filename = model_filename[-1] # cfg = WAN_CONFIGS['t2v-1.3B'] - print(f"Loading '{model_filename}' model...") - if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + print(f"Loading '{filename}' model...") + if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): model_factory = wan.DTT2V else: model_factory = wan.WanT2V @@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t return wan_model, pipe -def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): +def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): - print(f"Loading '{model_filename}' model...") + filename = model_filename[-1] + print(f"Loading '{filename}' model...") cfg = WAN_CONFIGS['i2v-14B'] wan_model = wan.WanI2V( @@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t def load_models(model_filename): global transformer_filename - transformer_filename = model_filename perc_reserved_mem_max = args.perc_reserved_mem_max major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) @@ -1892,19 +1900,27 @@ def load_models(model_filename): default_dtype = torch.float16 else: default_dtype = torch.float16 if args.fp16 else torch.bfloat16 - if default_dtype == torch.float16 : - if "quanto" in model_filename: - model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8") - download_models(model_filename, text_encoder_filename) + model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename] + updated_model_filename = [] + for filename in model_filelist: + if default_dtype == torch.float16 : + if "quanto_int8" in filename: + filename = filename.replace("quanto_int8", "quanto_fp16_int8") + elif "quanto_mbf16_int8": + filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8") + updated_model_filename.append(filename) + download_models(filename, text_encoder_filename) + model_filelist = updated_model_filename VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" - if test_class_i2v(model_filename): - res720P = "720p" in model_filename - wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) + transformer_filename = None + new_transformer_filename = model_filelist[-1] + if test_class_i2v(new_transformer_filename): + wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) else: - wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) - wan_model._model_file_name = model_filename - kwargs = { "extraModelsToQuantize": None} + wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) + wan_model._model_file_name = new_transformer_filename + kwargs = { "extraModelsToQuantize": None} if profile == 2 or profile == 4: kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 } # if profile == 4: @@ -1914,7 +1930,7 @@ def load_models(model_filename): offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) - + transformer_filename = new_transformer_filename return wan_model, offloadobj, pipe["transformer"] if not "P" in preload_model_policy: @@ -2033,13 +2049,7 @@ def apply_changes( state, preload_model_policy = server_config["preload_model_policy"] transformer_quantization = server_config["transformer_quantization"] transformer_types = server_config["transformer_types"] - model_filename = state["model_filename"] - model_transformer_type = get_model_type(model_filename) - if not model_transformer_type in transformer_types: - model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0] - model_filename = get_model_filename(model_transformer_type, transformer_quantization) - state["model_filename"] = model_filename if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ): model_choice = gr.Dropdown() else: @@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr frame_height, frame_width, _ = frames_list[0].shape if fit_canvas : - scale = min(height / frame_height, width / frame_width) + scale1 = min(height / frame_height, width / frame_width) + scale2 = min(height / frame_width, width / frame_height) + scale = max(scale1, scale2) else: scale = ((height * width ) / (frame_height * frame_width))**(1/2) @@ -2356,6 +2368,7 @@ def generate_video( seed, num_inference_steps, guidance_scale, + audio_guidance_scale, flow_shift, embedded_guidance_scale, repeat_generation, @@ -2375,8 +2388,10 @@ def generate_video( video_guide, keep_frames_video_guide, video_mask, + audio_guide, sliding_window_size, sliding_window_overlap, + sliding_window_overlap_noise, sliding_window_discard_last_frames, remove_background_image_ref, temporal_upsampling, @@ -2508,6 +2523,15 @@ def generate_video( # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 + diffusion_forcing = "diffusion_forcing" in model_filename + vace = "Vace" in model_filename + if diffusion_forcing: + fps = 24 + elif audio_guide != None: + fps = 23 + else: + fps = 16 + joint_pass = boost ==1 #and profile != 1 and profile != 3 # TeaCache trans.enable_teacache = tea_cache_setting > 0 @@ -2517,12 +2541,10 @@ def generate_video( trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100) if image2video: - if '480p' in model_filename: - trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - elif '720p' in model_filename: + if '720p' in model_filename: trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] else: - raise gr.Error("Teacache not supported for this model") + trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] else: if '1.3B' in model_filename: trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] @@ -2535,6 +2557,18 @@ def generate_video( if "recam" in model_filename: source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True) target_camera = model_mode + + audio_proj_split = None + audio_scale = None + audio_context_lens = None + if audio_guide != None: + from fantasytalking.infer import parse_audio + import librosa + duration = librosa.get_duration(path=audio_guide) + video_length = min(int(fps * duration // 4) * 4 + 5, video_length) + audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device ) + audio_scale = 1.0 + import random if seed == None or seed <0: seed = random.randint(0, 999999999) @@ -2551,11 +2585,11 @@ def generate_video( extra_generation = 0 initial_total_windows = 0 max_frames_to_generate = video_length - diffusion_forcing = "diffusion_forcing" in model_filename - vace = "Vace" in model_filename phantom = "phantom" in model_filename if diffusion_forcing or vace: reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) + else: + reuse_frames = 0 if diffusion_forcing and source_video != None: video_length += sliding_window_overlap sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size @@ -2571,10 +2605,8 @@ def generate_video( initial_total_windows = 1 first_window_video_length = video_length - fps = 24 if diffusion_forcing else 16 gen["sliding_window"] = sliding_window - while not abort: extra_generation += gen.get("extra_orders",0) gen["extra_orders"] = 0 @@ -2594,6 +2626,7 @@ def generate_video( guide_start_frame = 0 video_length = first_window_video_length gen["extra_windows"] = 0 + start_time = time.time() while not abort: if sliding_window: prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] @@ -2642,7 +2675,7 @@ def generate_video( if preprocess_type != None : send_cmd("progress", progress_args) - video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps) + video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps) keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") @@ -2678,8 +2711,7 @@ def generate_video( trans.teacache_counter = 0 trans.num_steps = num_inference_steps trans.teacache_skipped_steps = 0 - trans.previous_residual_uncond = None - trans.previous_residual_cond = None + trans.previous_residual = None if image2video: samples = wan_model.generate( @@ -2687,7 +2719,9 @@ def generate_video( image_start, image_end if image_end != None else None, frame_num=(video_length // 4)* 4 + 1, - max_area=MAX_AREA_CONFIGS[resolution_reformated], + # max_area=MAX_AREA_CONFIGS[resolution_reformated], + height = height, + width = width, shift=flow_shift, sampling_steps=num_inference_steps, guide_scale=guidance_scale, @@ -2702,7 +2736,11 @@ def generate_video( slg_end = slg_end_perc/100, cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, - add_frames_for_end_image = "image2video" in model_filename + add_frames_for_end_image = "image2video" in model_filename, + audio_cfg_scale= audio_guidance_scale, + audio_proj= audio_proj_split, + audio_scale= audio_scale, + audio_context_lens= audio_context_lens ) elif diffusion_forcing: samples = wan_model.generate( @@ -2720,14 +2758,17 @@ def generate_video( callback= callback, VAE_tile_size = VAE_tile_size, joint_pass = joint_pass, - addnoise_condition = 20, + slg_layers = slg_layers, + slg_start = slg_start_perc/100, + slg_end = slg_end_perc/100, + addnoise_condition = sliding_window_overlap_noise, ar_step = model_mode, #5 causal_block_size = 5, causal_attention = True, fps = fps, ) else: - samples = wan_model.generate( + samples = wan_model.generate( prompt, input_frames = src_video, input_ref_images= src_ref_images, @@ -2751,6 +2792,9 @@ def generate_video( slg_end = slg_end_perc/100, cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, + overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1), + overlap_noise = sliding_window_overlap_noise, + vace = vace ) except Exception as e: if temp_filename!= None and os.path.isfile(temp_filename): @@ -2782,11 +2826,11 @@ def generate_video( print('\n'.join(tb)) send_cmd("error", new_error) return + finally: + trans.previous_residual = None if trans.enable_teacache: - print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) - trans.previous_residual_uncond = None - trans.previous_residual_cond = None + print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" ) if samples != None: samples = samples.to("cpu") @@ -2810,14 +2854,27 @@ def generate_video( if discard_last_frames > 0: sample = sample[: , :-discard_last_frames] guide_start_frame -= discard_last_frames - pre_video_guide = sample[:, -reuse_frames:] + if reuse_frames == 0: + pre_video_guide = sample[:,9999 :] + else: + # noise_factor = 200/ 1000 + # pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor + pre_video_guide = sample[:, -reuse_frames:] + + if prefix_video != None: - sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1) + if reuse_frames == 0: + sample = torch.cat([ prefix_video[:, :], sample], dim = 1) + else: + sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1) prefix_video = None if sliding_window and window_no > 1: - sample = sample[: , reuse_frames:] - guide_start_frame -= reuse_frames + if reuse_frames == 0: + sample = sample[: , :] + else: + sample = sample[: , reuse_frames:] + guide_start_frame -= reuse_frames time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") if os.name == 'nt': file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" @@ -2875,18 +2932,23 @@ def generate_video( sample = torch.cat([frames_already_processed, sample], dim=1) frames_already_processed = sample - cache_video( - tensor=sample[None], - save_file=video_path, - fps=fps, - nrow=1, - normalize=True, - value_range=(-1, 1)) + if audio_guide == None: + cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1)) + else: + save_path_tmp = video_path[:-4] + "_tmp.mp4" + cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1)) + final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ] + import subprocess + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + + end_time = time.time() inputs = get_function_arguments(generate_video, locals()) inputs.pop("send_cmd") + inputs.pop("task_id") configs = prepare_inputs_dict("metadata", inputs) - + configs["generation_time"] = round(end_time-start_time) metadata_choice = server_config.get("metadata_type","metadata") if metadata_choice == "json": with open(video_path.replace('.mp4', '.json'), 'w') as f: @@ -3113,7 +3175,7 @@ def get_latest_status(state): prompt_no = gen["prompt_no"] prompts_max = gen.get("prompts_max",0) total_generation = gen.get("total_generation", 1) - repeat_no = gen["repeat_no"] + repeat_no = gen.get("repeat_no",0) total_generation += gen.get("extra_orders", 0) total_windows = gen.get("total_windows", 0) total_windows += gen.get("extra_windows", 0) @@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ): if target == "state": return inputs - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"] + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"] for k in unsaved_params: inputs.pop(k) @@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ): inputs.pop(k) if not "Vace" in model_filename or "diffusion_forcing" in model_filename: - unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"] + unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"] for k in unsaved_params: inputs.pop(k) + if not "fantasy" in model_filename: + inputs.pop("audio_guidance_scale") + + if target == "metadata": inputs = {k: v for k,v in inputs.items() if v != None } @@ -3511,6 +3577,7 @@ def save_inputs( seed, num_inference_steps, guidance_scale, + audio_guidance_scale, flow_shift, embedded_guidance_scale, repeat_generation, @@ -3530,8 +3597,10 @@ def save_inputs( video_guide, keep_frames_video_guide, video_mask, + audio_guide, sliding_window_size, sliding_window_overlap, + sliding_window_overlap_noise, sliding_window_discard_last_frames, remove_background_image_ref, temporal_upsampling, @@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non recammaster = "recam" in model_filename vace = "Vace" in model_filename phantom = "phantom" in model_filename + fantasy = "fantasy" in model_filename with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column: if diffusion_forcing: image_prompt_type_value= ui_defaults.get("image_prompt_type","S") @@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None)) - + audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy ) advanced_prompt = advanced_ui prompt_vars=[] @@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) with gr.Row(): - if test_class_i2v(model_filename): + if test_class_i2v(model_filename) and False: resolution = gr.Dropdown( choices=[ # 720p - ("720p", "1280x720"), - ("480p", "832x480"), + ("720p (same amount of pixels)", "1280x720"), + ("480p (same amount of pixels)", "832x480"), ], value=ui_defaults.get("resolution","480p"), label="Resolution (video will have the same height / width ratio than the original image)" @@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("1280x720 (16:9, 720p)", "1280x720"), ("720x1280 (9:16, 720p)", "720x1280"), ("1024x1024 (4:3, 720p)", "1024x024"), - # ("832x1104 (3:4, 720p)", "832x1104"), - # ("960x960 (1:1, 720p)", "960x960"), + ("832x1104 (3:4, 720p)", "832x1104"), + ("1104x832 (3:4, 720p)", "1104x832"), + ("960x960 (1:1, 720p)", "960x960"), # 480p ("960x544 (16:9, 540p)", "960x544"), ("544x960 (16:9, 540p)", "544x960"), ("832x480 (16:9, 480p)", "832x480"), ("480x832 (9:16, 480p)", "480x832"), - # ("832x624 (4:3, 540p)", "832x624"), - # ("624x832 (3:4, 540p)", "624x832"), - # ("720x720 (1:1, 540p)", "720x720"), + ("832x624 (4:3, 480p)", "832x624"), + ("624x832 (3:4, 480p)", "624x832"), + ("720x720 (1:1, 480p)", "720x720"), + ("512x512 (1:1, 480p)", "512x512"), ], value=ui_defaults.get("resolution","832x480"), - label="Resolution" + label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution" ) with gr.Row(): if recammaster: @@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True) elif vace: video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True) + elif fantasy: + video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True) else: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True) with gr.Row(): @@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non choices=[ ("Generate every combination of images and texts", 0), ("Match images and text prompts", 1), - ], visible= True, label= "Multiple Images as Texts Prompts" + ], visible= test_class_i2v(model_filename), label= "Multiple Images as Texts Prompts" ) with gr.Row(): guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True) + audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy) embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False) flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale") with gr.Row(): @@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Tab("Quality"): with gr.Row(): - gr.Markdown("Experimental: Skip Layer Guidance, should improve video quality") + gr.Markdown("Skip Layer Guidance (improves video quality)") with gr.Row(): slg_switch = gr.Dropdown( choices=[ @@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if diffusion_forcing: sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)") sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False) + sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect") + sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False) else: sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) + sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect") + sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) with gr.Tab("Miscellaneous", visible= not "recam" in model_filename): @@ -4167,8 +4244,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="RIFLEx positional embedding to generate long video" ) - with gr.Row(): - save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config) + with gr.Row(): + save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config) if not update_form: with gr.Column(): @@ -5035,7 +5112,7 @@ def create_demo(): theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md") with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main: - gr.Markdown("