Added support for fantasyspeaking model

DeepBeepMeep 2025-05-04 00:10:40 +02:00
parent 4ecc866c7b
commit bc9121ffc6
13 changed files with 857 additions and 440 deletions

View File

@ -10,6 +10,7 @@
## 🔥 Latest News!!
* May 5 2025: 👋 Wan 2.1GP v4.5: FantasySpeaking model: you can animate a talking head using a voice track. This works not only on people but also on objects. Also, better seamless transitions between Vace sliding windows for very long videos (see recommended settings). New high quality processing features (mixed 16/32 bits calculation and 32 bits VAE)
* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model to transfer people or objects into a video; it works quite well at 720p and with more than 30 steps
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Sliding Window section below). Note that Sky Reels v2 uses causal attention, which is only supported by Sdpa attention, so even if you choose another type of attention, some of the processing will still use Sdpa attention.
@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration.
Other recommended settings for Vace:
- Use a long prompt description, especially for the people / objects that are in the background and not in the reference images. This will ensure consistency between the windows.
- Set a medium-sized overlap window: long enough to give the model a sense of the motion, but short enough that any blurred overlapped frames do not turn the rest of the video into a blurred video.
- Truncate at least the last 4 frames of each generated window, as the last Vace frames tend to be blurry (see the sketch below).
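A minimal sketch of the last point above, assuming each sliding window has already been decoded to a plain list of frames (`stitch_windows`, `windows` and `trim_tail` are illustrative names, not part of Wan2.1GP's API):

```python
def stitch_windows(windows, trim_tail=4):
    """windows: list of per-window frame lists; returns one stitched frame list."""
    stitched = []
    for frames in windows:
        # drop the trailing frames of each window, since the last Vace frames tend to be blurry
        stitched.extend(frames[:-trim_tail] if len(frames) > trim_tail else frames)
    return stitched
```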
### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).

27
fantasytalking/infer.py Normal file
View File

@ -0,0 +1,27 @@
# Copyright Alibaba Inc. All Rights Reserved.
from transformers import Wav2Vec2Model, Wav2Vec2Processor
from .model import FantasyTalkingAudioConditionModel
from .utils import get_audio_features
def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
from mmgp import offload
from accelerate import init_empty_weights
from fantasytalking.model import AudioProjModel
with init_empty_weights():
proj_model = AudioProjModel( 768, 2048)
offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
proj_model.to(device).eval().requires_grad_(False)
wav2vec_model_dir = "ckpts/wav2vec"
wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
audio_proj_fea = proj_model(audio_wav2vec_fea)
pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
return audio_proj_split, audio_context_lens
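# Illustrative usage (hypothetical file name and values, not part of this commit):
#   audio_proj, audio_context_lens = parse_audio("voice.wav", num_frames=81, fps=23)
# audio_proj holds one zero-padded window of projected audio tokens per latent frame,
# and audio_context_lens gives each window's true length so the padded tokens can be
# ignored during cross-attention.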

162
fantasytalking/model.py Normal file
View File

@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from wan.modules.attention import pay_attention
class AudioProjModel(nn.Module):
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, audio_embeds):
context_tokens = self.proj(audio_embeds)
context_tokens = self.norm(context_tokens)
return context_tokens # [B,L,C]
class WanCrossAttentionProcessor(nn.Module):
def __init__(self, context_dim, hidden_dim):
super().__init__()
self.context_dim = context_dim
self.hidden_dim = hidden_dim
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
nn.init.zeros_(self.k_proj.weight)
nn.init.zeros_(self.v_proj.weight)
def __call__(
self,
q: torch.Tensor,
audio_proj: torch.Tensor,
latents_num_frames: int = 21,
audio_context_lens = None
) -> torch.Tensor:
"""
audio_proj: [B, 21, L3, C]
audio_context_lens: [B*21].
"""
b, l, n, d = q.shape
if len(audio_proj.shape) == 4:
audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
qkv_list = [audio_q, ip_key, ip_value]
del q, audio_q, ip_key, ip_value
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
audio_x = audio_x.view(b, l, n, d)
audio_x = audio_x.flatten(2)
elif len(audio_proj.shape) == 3:
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
qkv_list = [q, ip_key, ip_value]
del q, ip_key, ip_value
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
audio_x = audio_x.flatten(2)
return audio_x
class FantasyTalkingAudioConditionModel(nn.Module):
def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
super().__init__()
self.audio_in_dim = audio_in_dim
self.audio_proj_dim = audio_proj_dim
def split_audio_sequence(self, audio_proj_length, num_frames=81):
"""
Map the audio feature sequence to corresponding latent frame slices.
Args:
audio_proj_length (int): The total length of the audio feature sequence
(e.g., 173 in audio_proj[1, 173, 768]).
num_frames (int): The number of video frames in the training data (default: 81).
Returns:
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
(within the audio feature sequence) corresponding to a latent frame.
"""
# Average number of tokens per original video frame
tokens_per_frame = audio_proj_length / num_frames
# Each latent frame covers 4 video frames, and we want the center
tokens_per_latent_frame = tokens_per_frame * 4
half_tokens = int(tokens_per_latent_frame / 2)
pos_indices = []
for i in range(int((num_frames - 1) / 4) + 1):
if i == 0:
pos_indices.append(0)
else:
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
end_token = tokens_per_frame * (i * 4 + 1)
center_token = int((start_token + end_token) / 2) - 1
pos_indices.append(center_token)
# Build index ranges centered around each position
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
# Adjust the first range to avoid negative start index
pos_idx_ranges[0] = [
-(half_tokens * 2 - pos_idx_ranges[1][0]),
pos_idx_ranges[1][0],
]
return pos_idx_ranges
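# Worked example (illustrative): with audio_proj_length=173 and num_frames=81,
# tokens_per_frame is about 2.14 and half_tokens = 4, so 21 ranges are returned,
# e.g. [[-7, 1], [1, 9], [9, 17], ...] -- one slice of audio tokens centered on
# each latent frame, with the first range shifted so it ends where the second begins.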
def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
"""
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
if the range exceeds the input boundaries.
Args:
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
expand_length (int): Number of tokens to expand on both sides of each subsequence.
Returns:
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
Each element is a padded subsequence.
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
Useful for ignoring padding tokens in attention masks.
"""
pos_idx_ranges = [
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
]
sub_sequences = []
seq_len = input_tensor.size(1) # 173
max_valid_idx = seq_len - 1 # 172
k_lens_list = []
for start, end in pos_idx_ranges:
# Compute how much padding is needed at the front and back
pad_front = max(-start, 0)
pad_back = max(end - max_valid_idx, 0)
# Calculate the start and end indices of the valid part
valid_start = max(start, 0)
valid_end = min(end, max_valid_idx)
# Extract the valid part
if valid_start <= valid_end:
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
else:
valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
# Pad along the sequence dimension (dim 1)
padded_subseq = F.pad(
valid_part,
(0, 0, 0, pad_back + pad_front, 0, 0),
mode="constant",
value=0,
)
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
sub_sequences.append(padded_subseq)
return torch.stack(sub_sequences, dim=1), torch.tensor(
k_lens_list, dtype=torch.long
)

52
fantasytalking/utils.py Normal file
View File

@ -0,0 +1,52 @@
# Copyright Alibaba Inc. All Rights Reserved.
import imageio
import librosa
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def resize_image_by_longest_edge(image_path, target_size):
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = target_size / max(width, height)
new_size = (int(width * scale), int(height * scale))
return image.resize(new_size, Image.LANCZOS)
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
writer = imageio.get_writer(
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
)
for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame)
writer.append_data(frame)
writer.close()
def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
sr = 16000
audio_input, sample_rate = librosa.load(audio_path, sr=sr) # sampling rate is 16 kHz
start_time = 0
# end_time = (0 + (num_frames - 1) * 1) / fps
end_time = num_frames / fps
start_sample = int(start_time * sr)
end_sample = int(end_time * sr)
try:
audio_segment = audio_input[start_sample:end_sample]
except:
audio_segment = audio_input
input_values = audio_processor(
audio_segment, sampling_rate=sample_rate, return_tensors="pt"
).input_values.to("cuda")
with torch.no_grad():
fea = wav2vec(input_values).last_hidden_state
return fea

View File

@ -16,7 +16,7 @@ gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.4.1
mmgp==3.4.2
peft==0.14.0
mutagen
pydantic==2.10.6
@ -28,4 +28,5 @@ timm
segment-anything
omegaconf
hydra-core
librosa
# rembg==2.0.65

View File

@ -44,6 +44,8 @@ SUPPORTED_SIZES = {
VACE_SIZE_CONFIGS = {
'480*832': (480, 832),
'832*480': (832, 480),
'720*1280': (720, 1280),
'1280*720': (1280, 720),
}
VACE_MAX_AREA_CONFIGS = {

View File

@ -56,16 +56,18 @@ class DTT2V:
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload
# model_filename = "model.safetensors"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
# model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
# dtype = torch.float16
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json")
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
self.model.eval().requires_grad_(False)
@ -200,6 +202,9 @@ class DTT2V:
fps: int = 24,
VAE_tile_size = 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
callback = None,
):
self._interrupt = False
@ -211,6 +216,7 @@ class DTT2V:
if ar_step == 0:
causal_block_size = 1
causal_attention = False
i2v_extra_kwrags = {}
prefix_video = None
@ -252,31 +258,33 @@ class DTT2V:
prefix_video = output_video.to(self.device)
else:
causal_block_size = 1
causal_attention = False
ar_step = 0
prefix_video = image
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
if prefix_video.dtype == torch.uint8:
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
prefix_video = prefix_video.to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
predix_video_latent_length = prefix_video[0].shape[1]
prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)]
predix_video_latent_length = prefix_video.shape[1]
truncate_len = predix_video_latent_length % causal_block_size
if truncate_len != 0:
if truncate_len == predix_video_latent_length:
causal_block_size = 1
causal_attention = False
ar_step = 0
else:
print("the length of prefix video is truncated for the casual block size alignment.")
predix_video_latent_length -= truncate_len
prefix_video[0] = prefix_video[0][:, : predix_video_latent_length]
prefix_video = prefix_video[:, : predix_video_latent_length]
base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
latents = [latents]
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
@ -298,6 +306,8 @@ class DTT2V:
if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if self.model.enable_teacache:
x_count = 2 if self.do_classifier_free_guidance else 1
self.model.previous_residual = [None] * x_count
time_steps_comb = []
self.model.num_steps = updated_num_steps
for i, timestep_i in enumerate(step_matrix):
@ -309,7 +319,7 @@ class DTT2V:
self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
del time_steps_comb
from mmgp import offload
freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
kwrags = {
"freqs" :freqs,
"fps" : fps_embeds,
@ -320,27 +330,27 @@ class DTT2V:
}
kwrags.update(i2v_extra_kwrags)
for i, timestep_i in enumerate(tqdm(step_matrix)):
kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
offload.set_step_no_for_lora(self.model, i)
update_mask_i = step_update_mask[i]
valid_interval_start, valid_interval_end = valid_interval[i]
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
latent_model_input[:, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags.update({
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"current_step" : i,
})
@ -349,6 +359,7 @@ class DTT2V:
if True:
if not self.do_classifier_free_guidance:
noise_pred = self.model(
x=[latent_model_input],
context=[prompt_embeds],
**kwrags,
)[0]
@ -358,6 +369,7 @@ class DTT2V:
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
x=[latent_model_input, latent_model_input],
context= [prompt_embeds, negative_prompt_embeds],
**kwrags,
)
@ -365,12 +377,16 @@ class DTT2V:
return None
else:
noise_pred_cond = self.model(
x=[latent_model_input],
x_id=0,
context=[prompt_embeds],
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
x=[latent_model_input],
x_id=1,
context=[negative_prompt_embeds],
**kwrags,
)[0]
@ -380,18 +396,18 @@ class DTT2V:
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
latents[:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
latents[:, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0].squeeze(0), False)
callback(i, latents.squeeze(0), False)
x0 = latents[0].unsqueeze(0)
x0 = latents.unsqueeze(0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
return output_video

View File

@ -8,7 +8,7 @@ import sys
import types
from contextlib import contextmanager
from functools import partial
import json
import numpy as np
import torch
import torch.cuda.amp as amp
@ -84,13 +84,29 @@ class WanI2V:
config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {model_filename}")
logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
# fantasy = torch.load("c:/temp/fantasy.ckpt")
# proj_model = fantasy["proj_model"]
# audio_processor = fantasy["audio_processor"]
# offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
# for k,v in audio_processor.items():
# audio_processor[k] = v.to(torch.bfloat16)
# with open("fantasy_config.json", "r", encoding="utf-8") as reader:
# config_text = reader.read()
# config_json = json.loads(config_text)
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
# model_filename = [model_filename, "audio_processor_bf16.safetensors"]
# model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
# dtype = torch.float16
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
# offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False)
@ -102,6 +118,8 @@ class WanI2V:
input_prompt,
img,
img2 = None,
height =720,
width = 1280,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
@ -119,7 +137,11 @@ class WanI2V:
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
add_frames_for_end_image = True
add_frames_for_end_image = True,
audio_scale=None,
audio_cfg_scale=None,
audio_proj=None,
audio_context_lens=None,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@ -167,13 +189,21 @@ class WanI2V:
frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
h, w = img.shape[1:]
aspect_ratio = h / w
# aspect_ratio = h / w
scale1 = min(height / h, width / w)
scale2 = min(height / h, width / w)
scale = max(scale1, scale2)
new_height = int(h * scale)
new_width = int(w * scale)
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
new_height // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
new_width // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
@ -271,98 +301,101 @@ class WanI2V:
# sample videos
latent = noise
batch_size = latent.shape[0]
batch_size = 1
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {
'context': [context],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
'callback' : callback
}
kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
arg_null = {
'context': [context_null],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
'callback' : callback
}
arg_both= {
'context': [context, context_null],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
'callback' : callback
}
if audio_proj != None:
kwargs.update({
"audio_proj": audio_proj.to(self.dtype),
"audio_context_lens": audio_context_lens,
})
if self.model.enable_teacache:
self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
# self.model.to(self.device)
if callback != None:
callback(-1, None, True)
latent = latent.to(self.device)
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
latent_model_input = [latent.to(self.device)]
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
latent_model_input = latent
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
kwargs.update({
't' :timestep,
'current_step' :i,
})
if joint_pass:
if audio_proj == None:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
[latent_model_input, latent_model_input],
context=[context, context_null],
**kwargs)
else:
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
[latent_model_input, latent_model_input, latent_model_input],
context=[context, context, context_null],
audio_scale = [audio_scale, None, None ],
**kwargs)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input,
t=timestep,
current_step=i,
is_uncond=False,
**arg_c,
[latent_model_input],
context=[context],
audio_scale = None if audio_scale == None else [audio_scale],
x_id=0,
**kwargs,
)[0]
if self._interrupt:
return None
if audio_proj != None:
noise_pred_noaudio = self.model(
[latent_model_input],
x_id=1,
context=[context],
**kwargs,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input,
t=timestep,
current_step=i,
is_uncond=True,
slg_layers=slg_layers_local,
**arg_null,
[latent_model_input],
x_id=1 if audio_scale == None else 2,
context=[context_null],
**kwargs,
)[0]
if self._interrupt:
return None
del latent_model_input
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
positive_flat = noise_pred_cond.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
else:
noise_pred_uncond *= alpha
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond
if audio_scale == None:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
noise_pred_uncond, noise_pred_noaudio = None, None
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
@ -376,9 +409,6 @@ class WanI2V:
if callback is not None:
callback(i, latent, False)
# x0 = [latent.to(self.device, dtype=self.dtype)]
x0 = [latent]
# x0 = [lat_y]

View File

@ -57,15 +57,15 @@ def sageattn_wrapper(
):
q,k, v = qkv_list
padding_length = q.shape[0] -attention_length
q = q[:attention_length, :, : ].unsqueeze(0)
k = k[:attention_length, :, : ].unsqueeze(0)
v = v[:attention_length, :, : ].unsqueeze(0)
q = q[:attention_length, :, : ]
k = k[:attention_length, :, : ]
v = v[:attention_length, :, : ]
if True:
qkv_list = [q,k,v]
del q, k ,v
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
o = alt_sageattn(qkv_list, tensor_layout="NHD")
else:
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
o = sageattn(q, k, v, tensor_layout="NHD")
del q, k ,v
qkv_list.clear()
@ -107,14 +107,14 @@ def sdpa_wrapper(
attention_length
):
q,k, v = qkv_list
padding_length = q.shape[0] -attention_length
q = q[:attention_length, :].transpose(0,1).unsqueeze(0)
k = k[:attention_length, :].transpose(0,1).unsqueeze(0)
v = v[:attention_length, :].transpose(0,1).unsqueeze(0)
padding_length = q.shape[1] -attention_length
q = q[:attention_length, :].transpose(1,2)
k = k[:attention_length, :].transpose(1,2)
v = v[:attention_length, :].transpose(1,2)
o = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, is_causal=False
).squeeze(0).transpose(0,1)
).transpose(1,2)
del q, k ,v
qkv_list.clear()
@ -159,36 +159,72 @@ def pay_attention(
deterministic=False,
version=None,
force_attention= None,
cross_attn= False
cross_attn= False,
k_lens = None
):
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
q,k,v = qkv_list
qkv_list.clear()
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
assert b==1
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
q = q.to(v.dtype)
k = k.to(v.dtype)
# if q_scale is not None:
# q = q * q_scale
if b > 0 and k_lens != None and attn in ("sage2", "sdpa"):
# Poor man's variable-length attention
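# Consecutive batch entries that share the same key length are grouped into one
# chunk, each chunk's keys/values are truncated to that length, attention is run
# per chunk, and the per-chunk outputs are concatenated back along the batch dim.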
chunk_sizes = []
k_sizes = []
current_size = k_lens[0]
current_count= 1
for k_len in k_lens[1:]:
if k_len == current_size:
current_count += 1
else:
chunk_sizes.append(current_count)
k_sizes.append(current_size)
current_count = 1
current_size = k_len
chunk_sizes.append(current_count)
k_sizes.append(k_len)
if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]:
q_chunks =torch.split(q, chunk_sizes)
k_chunks =torch.split(k, chunk_sizes)
v_chunks =torch.split(v, chunk_sizes)
q, k, v = None, None, None
k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)]
v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)]
o = []
for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks):
qkv_list = [sub_q, sub_k, sub_v]
sub_q, sub_k, sub_v = None, None, None
o.append( pay_attention(qkv_list) )
q_chunks, k_chunks, v_chunks = None, None, None
o = torch.cat(o, dim = 0)
return o
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.'
)
if attn=="sage" or attn=="flash":
if b != 1 :
if k_lens == None:
k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
k = torch.cat([u[:v] for u, v in zip(k, k_lens)])
v = torch.cat([u[:v] for u, v in zip(v, k_lens)])
q = q.reshape(-1, *q.shape[-2:])
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
else:
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
# apply attention
if attn=="sage":
@ -207,7 +243,7 @@ def pay_attention(
qkv_list = [q,k,v]
del q,k,v
x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
# else:
# layer = offload.shared_state["layer"]
# embed_sizes = offload.shared_state["embed_sizes"]
@ -267,8 +303,8 @@ def pay_attention(
elif attn=="sdpa":
qkv_list = [q, k, v]
del q, k , v
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
del q ,k ,v
x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0)
elif attn=="flash" and version == 3:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
@ -302,59 +338,11 @@ def pay_attention(
# output
elif attn=="xformers":
x = memory_efficient_attention(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
) #.unsqueeze(0)
from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
if b != 1 and k_lens != None:
attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) )
x = memory_efficient_attention(q, k, v, attn_bias= attn_mask )
else:
x = memory_efficient_attention(q, k, v )
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return pay_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out

View File

@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module):
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
if block_mask == None:
qkv_list = [q,k,v]
del q,k,v
if block_mask == None:
x = pay_attention(
qkv_list,
window_size=self.window_size)
@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module):
.transpose(1, 2)
.contiguous()
)
del q,k,v
# if not self._flag_ar_attention:
# q = rope_apply(q, grid_sizes, freqs)
@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, xlist, context):
def forward(self, xlist, context, grid_sizes, *args, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context).view(b, -1, n, d)
# compute attention
v = v.contiguous().clone()
qvl_list=[q, k, v]
del q, k, v
x = pay_attention(qvl_list, cross_attn= True)
@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, context):
def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention):
del x
self.norm_q(q)
q= q.view(b, -1, n, d)
if audio_scale != None:
audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
k = self.k(context)
self.norm_k(k)
k = k.view(b, -1, n, d)
@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention):
img_x = img_x.flatten(2)
x += img_x
del img_x
if audio_scale != None:
x.add_(audio_x, alpha= audio_scale)
x = self.o(x)
return x
@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module):
hints= None,
context_scale=1.0,
cam_emb= None,
block_mask = None
block_mask = None,
audio_proj= None,
audio_context_lens= None,
audio_scale=None,
):
r"""
Args:
@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module):
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1)
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
x_mod += cam_emb
@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module):
y = y.to(attention_dtype)
ylist= [y]
del y
x += self.cross_attn(ylist, context).to(dtype)
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
y = self.norm2(x)
@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin):
eps=1e-6,
recammaster = False,
inject_sample_info = False,
fantasytalking_dim = 0,
):
r"""
Initialize the diffusion model backbone.
@ -742,42 +752,47 @@ class WanModel(ModelMixin, ConfigMixin):
block.projector.weight = nn.Parameter(torch.eye(dim))
block.projector.bias = nn.Parameter(torch.zeros(dim))
if fantasytalking_dim > 0:
from fantasytalking.model import WanCrossAttentionProcessor
for block in self.blocks:
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
def lock_layers_dtypes(self, dtype = torch.float32, force = False):
count = 0
layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding]
target_dype= dtype
layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2],
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
for block in self.blocks:
layer_list2 += [block.norm3]
if hasattr(self, "fps_embedding"):
layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
if hasattr(self, "vace_patch_embedding"):
layer_list += [self.vace_patch_embedding]
layer_list += [self.vace_blocks[0].before_proj]
layer_list2 += [self.vace_patch_embedding]
layer_list2 += [self.vace_blocks[0].before_proj]
for block in self.vace_blocks:
layer_list += [block.after_proj, block.norm3]
layer_list2 += [block.after_proj, block.norm3]
target_dype2 = hybrid_dtype if hybrid_dtype != None else dtype
# cam master
if hasattr(self.blocks[0], "projector"):
for block in self.blocks:
layer_list += [block.projector]
for block in self.blocks:
layer_list += [block.norm3]
for layer in layer_list:
if hasattr(layer, "weight"):
if layer.weight.dtype == dtype :
count += 1
elif force:
if hasattr(layer, "weight"):
layer.weight.data = layer.weight.data.to(dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(dtype)
count += 1
layer_list2 += [block.projector]
for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dype, target_dype2]):
for layer in current_layer_list:
layer._lock_dtype = dtype
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
layer.weight.data = layer.weight.data.to(current_dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(current_dtype)
if count > 0:
self._lock_dtype = dtype
@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin):
t = torch.stack([t])
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
e_list.append(time_emb)
best_deltas = None
best_threshold = 0.01
best_diff = 1000
best_signed_diff = 1000
@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin):
accumulated_rel_l1_distance =0
nb_steps = 0
diff = 1000
deltas = []
for i, t in enumerate(timesteps):
skip = False
if not (i<=start_step or i== len(timesteps)):
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
if not (i<=start_step or i== len(timesteps)-1):
delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
# deltas.append(delta)
accumulated_rel_l1_distance += delta
if accumulated_rel_l1_distance < threshold:
skip = True
# deltas.append("SKIP")
else:
accumulated_rel_l1_distance = 0
if not skip:
@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin):
diff = abs(signed_diff)
if diff < best_diff:
best_threshold = threshold
best_deltas = deltas
best_diff = diff
best_signed_diff = signed_diff
elif diff > best_diff:
@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin):
threshold += 0.01
self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
# print(f"deltas:{best_deltas}")
return best_threshold
@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin):
freqs = None,
pipeline = None,
current_step = 0,
is_uncond=False,
x_id= 0,
max_steps = 0,
slg_layers=None,
callback = None,
@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin):
fps = None,
causal_block_size = 1,
causal_attention = False,
x_neg = None
audio_proj=None,
audio_context_lens=None,
audio_scale=None,
):
# dtype = self.blocks[0].self_attn.q.weight.dtype
dtype = self.patch_embedding.weight.dtype
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin):
if torch.is_tensor(freqs) and freqs.device != device:
freqs = freqs.to(device)
x_list = x
joint_pass = len(x_list) > 1
is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ]
last_x_idx = 0
for i, (is_source, x) in enumerate(zip(is_source_x, x_list)):
if is_source:
x_list[i] = x_list[0].clone()
last_x_idx = i
else:
# image source
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
x = torch.cat([x, y], dim=0)
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
if x_neg !=None:
x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
x_list[i] = x
x, y = None, None
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0]
if causal_attention : #causal_block_size > 0:
frame_num = embed_sizes[0]
height = embed_sizes[1]
width = embed_sizes[2]
block_mask = None
if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
frame_num = grid_sizes[0]
height = grid_sizes[1]
width = grid_sizes[2]
block_num = frame_num // causal_block_size
range_tensor = torch.arange(block_num).view(-1, 1)
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin):
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
del causal_mask
offload.shared_state["embed_sizes"] = embed_sizes
offload.shared_state["embed_sizes"] = grid_sizes
offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps
x = [u.flatten(2).transpose(1, 2) for u in x]
x = x[0]
if x_neg !=None:
x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
x_neg = x_neg[0]
_flag_df = t.dim() == 2
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).to(dtype) # float()
fps_emb = self.fps_embedding(fps).to(e.dtype)
if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else:
@ -914,29 +941,27 @@ class WanModel(ModelMixin, ConfigMixin):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
joint_pass = len(context) > 0
x_list = [x]
if joint_pass:
if x_neg == None:
x_list += [x.clone() for i in range(len(context) - 1) ]
else:
x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
is_uncond = False
del x
context_list = context
if audio_scale != None:
audio_scale_list = audio_scale
else:
audio_scale_list = [None] * len(x_list)
# arguments
kwargs = dict(
grid_sizes=grid_sizes,
freqs=freqs,
cam_emb = cam_emb
cam_emb = cam_emb,
block_mask = block_mask,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
)
if vace_context == None:
hints_list = [None ] *len(x_list)
else:
# embeddings
# Vace embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = c[0]
@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin):
should_calc = True
if self.enable_teacache:
if is_uncond:
if x_id != 0:
should_calc = self.should_calc
else:
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin):
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
self.accumulated_rel_l1_distance += delta
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
self.teacache_skipped_steps += 1
# print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
# print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" )
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin):
self.should_calc = should_calc
if not should_calc:
if joint_pass:
for i, x in enumerate(x_list):
x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
x += self.previous_residual[i]
else:
x = x_list[0]
x += self.previous_residual[x_id]
x = None
else:
if self.enable_teacache:
if joint_pass or is_uncond:
self.previous_residual_uncond = None
if joint_pass or not is_uncond:
self.previous_residual_cond = None
ori_hidden_states = x_list[0].clone()
if joint_pass:
self.previous_residual = [ None ] * len(self.previous_residual)
else:
self.previous_residual[x_id] = None
ori_hidden_states = [ None ] * len(x_list)
ori_hidden_states[0] = x_list[0].clone()
for i in range(1, len(x_list)):
ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone()
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin):
if pipeline._interrupt:
return [None] * len(x_list)
if slg_layers is not None and block_idx in slg_layers:
if is_uncond and not joint_pass:
if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
if not joint_pass:
continue
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
else:
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)):
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
del x
del context, hints
if self.enable_teacache:
if joint_pass:
self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
self.previous_residual_uncond = ori_hidden_states
torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
if i == 0 or is_source and i != last_x_idx :
self.previous_residual[i] = torch.sub(x, ori)
else:
residual = ori_hidden_states # just to have a readable code
torch.sub(x_list[0], ori_hidden_states, out=residual)
if i==1 or is_uncond:
self.previous_residual_uncond = residual
self.previous_residual[i] = ori
torch.sub(x, ori, out=self.previous_residual[i])
ori_hidden_states[i] = None
x , ori = None, None
else:
self.previous_residual_cond = residual
residual = ori_hidden_states[0] # just to have a readable code
torch.sub(x_list[0], ori_hidden_states[0], out=residual)
self.previous_residual[x_id] = residual
residual, ori_hidden_states = None, None
for i, x in enumerate(x_list):
@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin):
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
for u in x:
u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
out.append(u)
return out

View File

@ -140,7 +140,7 @@ def sageattn(
elif arch == "sm90":
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
elif arch == "sm120":
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
else:
raise ValueError(f"Unsupported CUDA architecture: {arch}")

View File

@ -78,15 +78,16 @@ class WanT2V:
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload
# model_filename
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
# offload.load_model_data(self.model, "e:/vace.safetensors")
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
# self.model.to(torch.bfloat16)
# self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
# offload.save_model(self.model, "phantom_1.3B.safetensors")
@ -95,7 +96,7 @@ class WanT2V:
self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename:
if "Vace" in model_filename[-1]:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
@ -107,7 +108,7 @@ class WanT2V:
self.adapt_vace_model()
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0):
if ref_images is None:
ref_images = [None] * len(frames)
else:
@ -119,6 +120,11 @@ class WanT2V:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
# inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive]
# if overlapped_latents > 0:
# for t in inactive:
# t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
@ -288,6 +294,9 @@ class WanT2V:
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
overlapped_latents = 0,
overlap_noise = 0,
vace = False
):
r"""
Generates video frames from text prompt using diffusion process.
@ -343,20 +352,20 @@ class WanT2V:
size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device)
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device)
del source_video
# Process target camera (recammaster)
from wan.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
if input_frames != None:
if vace :
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
@ -365,10 +374,10 @@ class WanT2V:
else:
if input_ref_images != None: # Phantom Ref images
phantom = True
input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = torch.zeros_like(input_ref_images)
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
@ -405,37 +414,48 @@ class WanT2V:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
latents = noise[0]
del noise
batch_size =len(latents)
batch_size = 1
if target_camera != None:
shape = list(latents[0].shape[1:])
shape = list(latents.shape[1:])
shape[0] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None:
kwargs.update({'cam_emb': cam_emb})
if input_frames != None:
if vace:
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if overlapped_latents > 0:
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
if self.model.enable_teacache:
x_count = 3 if phantom else 2
self.model.previous_residual = [None] * x_count
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
if vace and overlapped_latents > 0 :
# noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
noise_factor = overlap_noise / 1000 # * (999-t) / 999
# noise_factor = overlap_noise / 1000 # * t / 999
for zz, zz_r in zip(z, z_reactive):
zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor
if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
latent_model_input = torch.cat([latents, source_latents], dim=1)
else:
latent_model_input = latents
slg_layers_local = None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
timestep = [t]
offload.set_step_no_for_lora(self.model, i)
timestep = torch.stack(timestep)
@ -444,38 +464,38 @@ class WanT2V:
if joint_pass:
if phantom:
pos_it, pos_i, neg = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
context = [context, context_null, context_null], **kwargs)
else:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
[latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
if self._interrupt:
return None
else:
if phantom:
pos_it = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
)[0]
if self._interrupt:
return None
pos_i = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
)[0]
if self._interrupt:
return None
neg = self.model(
[torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
)[0]
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
[latent_model_input], x_id = 0, context = [context], **kwargs)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0]
[latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
if self._interrupt:
return None
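# When joint_pass is off, each forward pass runs separately and x_id appears to select which slot of
# self.model.previous_residual (sized 3 for Phantom, 2 otherwise, see the TeaCache setup above) is
# read and updated: 0 = conditional, 1 = unconditional, 2 = Phantom negative pass.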
@ -505,21 +525,21 @@ class WanT2V:
temp_x0 = sample_scheduler.step(
noise_pred[:, :target_shape[1]].unsqueeze(0),
t,
latents[0].unsqueeze(0),
latents.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
latents = temp_x0.squeeze(0)
del temp_x0
if callback is not None:
callback(i, latents[0], False)
callback(i, latents, False)
x0 = latents
x0 = [latents]
if input_frames == None:
if phantom:
# phantom post processing
x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
videos = self.vae.decode(x0, VAE_tile_size)
else:
# vace post processing

325
wgp.py

@ -40,7 +40,7 @@ global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.1"
target_mmgp_version = "3.4.2"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version:
lock = threading.Lock()
current_task_id = None
task_id = 0
# progress_tracker = {}
# tracker_lock = threading.Lock()
# def download_ffmpeg():
# if os.name != 'nt': return
# exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
# if all(os.path.exists(e) for e in exes): return
# api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
# r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
# assets = r.json().get('assets', [])
# zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
# if not zip_asset: return
# zip_url = zip_asset['browser_download_url']
# zip_name = zip_asset['name']
# with requests.get(zip_url, stream=True) as resp:
# total = int(resp.headers.get('Content-Length', 0))
# with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
# for chunk in resp.iter_content(chunk_size=8192):
# f.write(chunk)
# pbar.update(len(chunk))
# with zipfile.ZipFile(zip_name) as z:
# for f in z.namelist():
# if f.endswith(tuple(exes)) and '/bin/' in f:
# z.extract(f)
# os.rename(f, os.path.basename(f))
# os.remove(zip_name)
def download_ffmpeg():
if os.name != 'nt': return
exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
if all(os.path.exists(e) for e in exes): return
api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
assets = r.json().get('assets', [])
zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
if not zip_asset: return
zip_url = zip_asset['browser_download_url']
zip_name = zip_asset['name']
with requests.get(zip_url, stream=True) as resp:
total = int(resp.headers.get('Content-Length', 0))
with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
pbar.update(len(chunk))
with zipfile.ZipFile(zip_name) as z:
for f in z.namelist():
if f.endswith(tuple(exes)) and '/bin/' in f:
z.extract(f)
os.rename(f, os.path.basename(f))
os.remove(zip_name)
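# download_ffmpeg() is Windows-only: it queries the latest release of GyanD/codexffmpeg on GitHub,
# streams the "essentials" zip with a tqdm progress bar and extracts ffmpeg.exe / ffprobe.exe /
# ffplay.exe from its bin/ folder into the working directory. It relies on requests, zipfile, os and
# tqdm, which are presumably imported at the top of wgp.py (outside this hunk).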
def format_time(seconds):
if seconds < 60:
@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice):
resolution = inputs["resolution"]
width, height = resolution.split("x")
width, height = int(width), int(height)
if test_class_i2v(model_filename):
if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
return
resolution = str(width) + "*" + str(height)
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
gr.Info(f"Resolution {resolution} not supported by image 2 video")
return
# if test_class_i2v(model_filename):
# if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
# gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
# return
# resolution = str(width) + "*" + str(height)
# if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
# gr.Info(f"Resolution {resolution} not supported by image 2 video")
# return
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
@ -533,7 +531,7 @@ def save_queue_action(state):
task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
for key in image_keys:
images_pil = params_copy.get(key)
@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
params['state'] = state
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
loaded_pil_images = {}
loaded_video_paths = {}
@ -925,7 +923,7 @@ def autosave_queue():
task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"]
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
for key in image_keys:
images_pil = params_copy.get(key)
@ -1418,32 +1416,35 @@ else:
text = reader.read()
server_config = json.loads(text)
# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
# if Path(src_path).is_file():
# shutil.move(src_path, tgt_path) )
# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
# if Path(path).is_file():
# os.remove(path)
# Deprecated models
for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors",
"sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors",
"wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
]:
if Path(os.path.join("ckpts" , path)).is_file():
os.remove( os.path.join("ckpts" , path))
path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
os.remove(path)
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
"ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
"ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
"ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
def get_dependent_models(model_filename, quantization ):
if "fantasy" in model_filename:
return [get_model_filename("i2v_720p", quantization)]
else:
return []
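# The FantasySpeaking checkpoint appears to extend the base 720p image2video transformer, so that
# model is declared as a dependency and downloaded/loaded first. Illustrative call (the returned
# path depends on get_model_filename and the selected quantization, so the value below is only an
# example):
#   get_dependent_models("ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors", "int8")
#   # -> e.g. ["ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors"]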
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
"phantom_1.3B" : "phantom_1.3B", }
"phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" }
def get_model_type(model_filename):
@ -1453,7 +1454,7 @@ def get_model_type(model_filename):
raise Exception("Unknown model:" + model_filename)
def test_class_i2v(model_filename):
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename
def get_model_name(model_filename, description_container = [""]):
if "Fun" in model_filename:
@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]):
model_name = "Wan2.1 Phantom"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
elif "fantasy" in model_filename:
model_name = "Wan2.1 Fantasy Speaking 720p"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
else:
model_name = "Wan2.1 text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
@ -1536,6 +1541,7 @@ def get_default_settings(filename):
"repeat_generation": 1,
"multi_images_gen_type": 0,
"guidance_scale": 5.0,
"audio_guidance_scale": 5.0,
"flow_shift": get_default_flow(filename, i2v),
"negative_prompt": "",
"activated_loras": [],
@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0:
@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
cfg = WAN_CONFIGS['t2v-14B']
filename = model_filename[-1]
# cfg = WAN_CONFIGS['t2v-1.3B']
print(f"Loading '{model_filename}' model...")
if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
print(f"Loading '{filename}' model...")
if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
model_factory = wan.DTT2V
else:
model_factory = wan.WanT2V
@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
return wan_model, pipe
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
print(f"Loading '{model_filename}' model...")
filename = model_filename[-1]
print(f"Loading '{filename}' model...")
cfg = WAN_CONFIGS['i2v-14B']
wan_model = wan.WanI2V(
@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t
def load_models(model_filename):
global transformer_filename
transformer_filename = model_filename
perc_reserved_mem_max = args.perc_reserved_mem_max
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
@ -1892,18 +1900,26 @@ def load_models(model_filename):
default_dtype = torch.float16
else:
default_dtype = torch.float16 if args.fp16 else torch.bfloat16
model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
updated_model_filename = []
for filename in model_filelist:
if default_dtype == torch.float16 :
if "quanto" in model_filename:
model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
download_models(model_filename, text_encoder_filename)
if "quanto_int8" in filename:
filename = filename.replace("quanto_int8", "quanto_fp16_int8")
elif "quanto_mbf16_int8":
filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
updated_model_filename.append(filename)
download_models(filename, text_encoder_filename)
model_filelist = updated_model_filename
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
if test_class_i2v(model_filename):
res720P = "720p" in model_filename
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
transformer_filename = None
new_transformer_filename = model_filelist[-1]
if test_class_i2v(new_transformer_filename):
wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
else:
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model._model_file_name = model_filename
wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model._model_file_name = new_transformer_filename
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
@ -1914,7 +1930,7 @@ def load_models(model_filename):
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_filename = new_transformer_filename
return wan_model, offloadobj, pipe["transformer"]
if not "P" in preload_model_policy:
@ -2033,13 +2049,7 @@ def apply_changes( state,
preload_model_policy = server_config["preload_model_policy"]
transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"]
model_filename = state["model_filename"]
model_transformer_type = get_model_type(model_filename)
if not model_transformer_type in transformer_types:
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
model_choice = gr.Dropdown()
else:
@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
frame_height, frame_width, _ = frames_list[0].shape
if fit_canvas :
scale = min(height / frame_height, width / frame_width)
scale1 = min(height / frame_height, width / frame_width)
scale2 = min(height / frame_width, width / frame_height)
scale = max(scale1, scale2)
else:
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
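# With fit_canvas the scale is now evaluated for both orientations of the target canvas and the
# larger fit is kept, so a portrait clip given a landscape target (or vice versa) is no longer shrunk
# to fit the wrong orientation; the requested resolution acts as a maximum. Worked example with
# assumed values height=480, width=832 and a 720x1280 portrait frame:
#   scale1 = min(480/720, 832/1280) = 0.65
#   scale2 = min(480/1280, 832/720) = 0.375
#   scale  = max(scale1, scale2)    = 0.65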
@ -2356,6 +2368,7 @@ def generate_video(
seed,
num_inference_steps,
guidance_scale,
audio_guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
@ -2375,8 +2388,10 @@ def generate_video(
video_guide,
keep_frames_video_guide,
video_mask,
audio_guide,
sliding_window_size,
sliding_window_overlap,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
@ -2508,6 +2523,15 @@ def generate_video(
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
diffusion_forcing = "diffusion_forcing" in model_filename
vace = "Vace" in model_filename
if diffusion_forcing:
fps = 24
elif audio_guide != None:
fps = 23
else:
fps = 16
joint_pass = boost ==1 #and profile != 1 and profile != 3
# TeaCache
trans.enable_teacache = tea_cache_setting > 0
@ -2517,12 +2541,10 @@ def generate_video(
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if image2video:
if '480p' in model_filename:
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
elif '720p' in model_filename:
if '720p' in model_filename:
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else:
raise gr.Error("Teacache not supported for this model")
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
else:
if '1.3B' in model_filename:
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
@ -2535,6 +2557,18 @@ def generate_video(
if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
target_camera = model_mode
audio_proj_split = None
audio_scale = None
audio_context_lens = None
if audio_guide != None:
from fantasytalking.infer import parse_audio
import librosa
duration = librosa.get_duration(path=audio_guide)
video_length = min(int(fps * duration // 4) * 4 + 5, video_length)
audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device )
audio_scale = 1.0
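# For FantasySpeaking the number of frames is derived from the audio length and kept on the 4k+1
# grid used elsewhere for frame_num, then capped by the frame count requested in the UI.
# Worked example with an assumed 4.0 s voice track at fps = 23:
#   int(23 * 4.0 // 4) * 4 + 5 = 23 * 4 + 5 = 97 frames  (about 4.2 s of video)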
import random
if seed == None or seed <0:
seed = random.randint(0, 999999999)
@ -2551,11 +2585,11 @@ def generate_video(
extra_generation = 0
initial_total_windows = 0
max_frames_to_generate = video_length
diffusion_forcing = "diffusion_forcing" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
if diffusion_forcing or vace:
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
else:
reuse_frames = 0
if diffusion_forcing and source_video != None:
video_length += sliding_window_overlap
sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
@ -2571,10 +2605,8 @@ def generate_video(
initial_total_windows = 1
first_window_video_length = video_length
fps = 24 if diffusion_forcing else 16
gen["sliding_window"] = sliding_window
while not abort:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
@ -2594,6 +2626,7 @@ def generate_video(
guide_start_frame = 0
video_length = first_window_video_length
gen["extra_windows"] = 0
start_time = time.time()
while not abort:
if sliding_window:
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
@ -2642,7 +2675,7 @@ def generate_video(
if preprocess_type != None :
send_cmd("progress", progress_args)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
@ -2678,8 +2711,7 @@ def generate_video(
trans.teacache_counter = 0
trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
trans.previous_residual = None
if image2video:
samples = wan_model.generate(
@ -2687,7 +2719,9 @@ def generate_video(
image_start,
image_end if image_end != None else None,
frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution_reformated],
# max_area=MAX_AREA_CONFIGS[resolution_reformated],
height = height,
width = width,
shift=flow_shift,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
@ -2702,7 +2736,11 @@ def generate_video(
slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
add_frames_for_end_image = "image2video" in model_filename
add_frames_for_end_image = "image2video" in model_filename,
audio_cfg_scale= audio_guidance_scale,
audio_proj= audio_proj_split,
audio_scale= audio_scale,
audio_context_lens= audio_context_lens
)
elif diffusion_forcing:
samples = wan_model.generate(
@ -2720,7 +2758,10 @@ def generate_video(
callback= callback,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
addnoise_condition = 20,
slg_layers = slg_layers,
slg_start = slg_start_perc/100,
slg_end = slg_end_perc/100,
addnoise_condition = sliding_window_overlap_noise,
ar_step = model_mode, #5
causal_block_size = 5,
causal_attention = True,
@ -2751,6 +2792,9 @@ def generate_video(
slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1),
overlap_noise = sliding_window_overlap_noise,
vace = vace
)
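# overlapped_latents converts the overlapped video frames into latent frames, assuming the usual Wan
# VAE temporal compression (the first latent frame covers 1 video frame, every following latent
# frame covers 4). Example: reuse_frames = 17 gives (17 - 1) // 4 + 1 = 5 latent frames, which
# matches the slice re-noised at each step inside the sampling loop above.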
except Exception as e:
if temp_filename!= None and os.path.isfile(temp_filename):
@ -2782,11 +2826,11 @@ def generate_video(
print('\n'.join(tb))
send_cmd("error", new_error)
return
finally:
trans.previous_residual = None
if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
if samples != None:
samples = samples.to("cpu")
@ -2810,14 +2854,27 @@ def generate_video(
if discard_last_frames > 0:
sample = sample[: , :-discard_last_frames]
guide_start_frame -= discard_last_frames
if reuse_frames == 0:
pre_video_guide = sample[:,9999 :]
else:
# noise_factor = 200/ 1000
# pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor
pre_video_guide = sample[:, -reuse_frames:]
if prefix_video != None:
if reuse_frames == 0:
sample = torch.cat([ prefix_video[:, :], sample], dim = 1)
else:
sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
prefix_video = None
if sliding_window and window_no > 1:
if reuse_frames == 0:
sample = sample[: , :]
else:
sample = sample[: , reuse_frames:]
guide_start_frame -= reuse_frames
guide_start_frame -= reuse_frames
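# Window stitching summary: the (possibly blurry) trailing frames are dropped first, the last
# reuse_frames frames of what remains seed the next window as pre_video_guide, and for every window
# after the first the leading reuse_frames frames are cut since they duplicate the previous window.
# Rough example with assumed settings video_length = 81, reuse_frames = 17, discard_last_frames = 4:
# each extra window adds 81 - 4 - 17 = 60 new frames.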
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
@ -2875,18 +2932,23 @@ def generate_video(
sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample
cache_video(
tensor=sample[None],
save_file=video_path,
fps=fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
if audio_guide == None:
cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_path_tmp = video_path[:-4] + "_tmp.mp4"
cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ]
import subprocess
subprocess.run(final_command, check=True)
os.remove(save_path_tmp)
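# When an audio guide is provided, the video is first written without sound to a temporary file and
# then muxed with the original audio track via ffmpeg (H.264 video, AAC audio); "-shortest" trims the
# result to the shorter of the two streams. This assumes an ffmpeg binary is available on the PATH
# (or, on Windows, was fetched by download_ffmpeg()).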
end_time = time.time()
inputs = get_function_arguments(generate_video, locals())
inputs.pop("send_cmd")
inputs.pop("task_id")
configs = prepare_inputs_dict("metadata", inputs)
configs["generation_time"] = round(end_time-start_time)
metadata_choice = server_config.get("metadata_type","metadata")
if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f:
@ -3113,7 +3175,7 @@ def get_latest_status(state):
prompt_no = gen["prompt_no"]
prompts_max = gen.get("prompts_max",0)
total_generation = gen.get("total_generation", 1)
repeat_no = gen["repeat_no"]
repeat_no = gen.get("repeat_no",0)
total_generation += gen.get("extra_orders", 0)
total_windows = gen.get("total_windows", 0)
total_windows += gen.get("extra_windows", 0)
@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ):
if target == "state":
return inputs
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"]
for k in unsaved_params:
inputs.pop(k)
@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ):
inputs.pop(k)
if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"]
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
for k in unsaved_params:
inputs.pop(k)
if not "fantasy" in model_filename:
inputs.pop("audio_guidance_scale")
if target == "metadata":
inputs = {k: v for k,v in inputs.items() if v != None }
@ -3511,6 +3577,7 @@ def save_inputs(
seed,
num_inference_steps,
guidance_scale,
audio_guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
@ -3530,8 +3597,10 @@ def save_inputs(
video_guide,
keep_frames_video_guide,
video_mask,
audio_guide,
sliding_window_size,
sliding_window_overlap,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
recammaster = "recam" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
fantasy = "fantasy" in model_filename
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
if diffusion_forcing:
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpainting, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy )
advanced_prompt = advanced_ui
prompt_vars=[]
@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
with gr.Row():
if test_class_i2v(model_filename):
if test_class_i2v(model_filename) and False:
resolution = gr.Dropdown(
choices=[
# 720p
("720p", "1280x720"),
("480p", "832x480"),
("720p (same amount of pixels)", "1280x720"),
("480p (same amount of pixels)", "832x480"),
],
value=ui_defaults.get("resolution","480p"),
label="Resolution (video will have the same height / width ratio than the original image)"
@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
# ("832x1104 (3:4, 720p)", "832x1104"),
# ("960x960 (1:1, 720p)", "960x960"),
("832x1104 (3:4, 720p)", "832x1104"),
("1104x832 (3:4, 720p)", "1104x832"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
# ("832x624 (4:3, 540p)", "832x624"),
# ("624x832 (3:4, 540p)", "624x832"),
# ("720x720 (1:1, 540p)", "720x720"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
label="Resolution"
label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
)
with gr.Row():
if recammaster:
@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
elif vace:
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
elif fantasy:
video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
else:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
with gr.Row():
@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
], visible= True, label= "Multiple Images as Texts Prompts"
], visible= test_class_i2v(model_filename), label= "Multiple Images as Text Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Tab("Quality"):
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
gr.Markdown("<B>Skip Layer Guidance (improves video quality)</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if diffusion_forcing:
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
else:
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
@ -5035,7 +5112,7 @@ def create_demo():
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.5 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })
@ -5076,7 +5153,7 @@ def create_demo():
if __name__ == "__main__":
atexit.register(autosave_queue)
# download_ffmpeg()
download_ffmpeg()
# threading.Thread(target=runner, daemon=True).start()
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)