Added support for the FantasySpeaking model

This commit is contained in:
DeepBeepMeep 2025-05-04 00:10:40 +02:00
parent 4ecc866c7b
commit bc9121ffc6
13 changed files with 857 additions and 440 deletions

View File

@ -10,6 +10,7 @@
## 🔥 Latest News!!
* May 5 2025: 👋 Wan 2.1GP v4.5: FantasySpeaking model: you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos (see recommended settings). New high quality processing features (mixed 16/32-bit calculation and 32-bit VAE).
* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model to transfer people or objects into a video; it works quite well at 720p and with a number of steps > 30.
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below). Note that Skyreel uses causal attention that is only supported by Sdpa attention, so even if you choose another type of attention, some of the processes will still use Sdpa attention.
@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration.
Other recommended settings for Vace:
- Use a long prompt description, especially for the people / objects that are in the background and not in reference images. This will ensure consistency between the windows.
- Set a medium-size overlap window: long enough to give the model a sense of the motion, but short enough that any overlapped blurred frames do not turn the rest of the video into a blurred video.
- Truncate at least the last 4 frames of each generated window, as Vace's last frames tend to be blurry (see the sketch after this list).
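As an illustration of the last point, here is a minimal sketch of trimming the blurry tail of each window before stitching the windows together. The helper name and tensor layout are hypothetical, not part of this repository:

```python
import torch

def stitch_vace_windows(windows, trim_tail=4):
    """Concatenate per-window frame tensors (each [frames, H, W, C]) along time,
    dropping the last `trim_tail` frames of every window except the final one."""
    kept = [w[:-trim_tail] if i < len(windows) - 1 else w for i, w in enumerate(windows)]
    return torch.cat(kept, dim=0)
```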
### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
With this mode (which works for the moment only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).

27
fantasytalking/infer.py Normal file
View File

@ -0,0 +1,27 @@
# Copyright Alibaba Inc. All Rights Reserved.

from transformers import Wav2Vec2Model, Wav2Vec2Processor

from .model import FantasyTalkingAudioConditionModel
from .utils import get_audio_features


def parse_audio(audio_path, num_frames, fps=23, device="cuda"):
    fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
    from mmgp import offload
    from accelerate import init_empty_weights
    from fantasytalking.model import AudioProjModel
    with init_empty_weights():
        proj_model = AudioProjModel(768, 2048)
    offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
    proj_model.to(device).eval().requires_grad_(False)

    wav2vec_model_dir = "ckpts/wav2vec"
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
    audio_wav2vec_fea = get_audio_features(wav2vec, wav2vec_processor, audio_path, fps, num_frames)

    audio_proj_fea = proj_model(audio_wav2vec_fea)
    pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(
        audio_proj_fea, pos_idx_ranges, expand_length=4
    )  # [b,21,9+8,768]

    return audio_proj_split, audio_context_lens
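For reference, a minimal usage sketch of `parse_audio` as defined above. It assumes the checkpoints the function hard-codes (`ckpts/fantasy_proj_model.safetensors` and `ckpts/wav2vec`) are already downloaded; the audio path is a placeholder:

```python
# Sketch: build per-latent-frame audio conditioning for an 81-frame clip.
from fantasytalking.infer import parse_audio

audio_proj_split, audio_context_lens = parse_audio(
    "driving_voice.wav",  # placeholder path to the voice track
    num_frames=81,        # 81 video frames -> (81 - 1) // 4 + 1 = 21 latent frames
    fps=23,
    device="cuda",
)
print(audio_proj_split.shape)   # roughly [1, 21, tokens_per_latent_frame, 2048]
print(audio_context_lens)       # unpadded token count per latent frame (for attention masking)
```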

162
fantasytalking/model.py Normal file
View File

@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from wan.modules.attention import pay_attention


class AudioProjModel(nn.Module):
    def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, audio_embeds):
        context_tokens = self.proj(audio_embeds)
        context_tokens = self.norm(context_tokens)
        return context_tokens  # [B,L,C]


class WanCrossAttentionProcessor(nn.Module):
    def __init__(self, context_dim, hidden_dim):
        super().__init__()

        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        nn.init.zeros_(self.k_proj.weight)
        nn.init.zeros_(self.v_proj.weight)

    def __call__(
        self,
        q: torch.Tensor,
        audio_proj: torch.Tensor,
        latents_num_frames: int = 21,
        audio_context_lens=None,
    ) -> torch.Tensor:
        """
        audio_proj: [B, 21, L3, C]
        audio_context_lens: [B*21].
        """
        b, l, n, d = q.shape

        if len(audio_proj.shape) == 4:
            audio_q = q.view(b * latents_num_frames, -1, n, d)  # [b, 21, l1, n, d]
            ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            qkv_list = [audio_q, ip_key, ip_value]
            del q, audio_q, ip_key, ip_value
            audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)
            audio_x = audio_x.view(b, l, n, d)
            audio_x = audio_x.flatten(2)
        elif len(audio_proj.shape) == 3:
            ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
            qkv_list = [q, ip_key, ip_value]
            del q, ip_key, ip_value
            audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)
            audio_x = audio_x.flatten(2)
        return audio_x


class FantasyTalkingAudioConditionModel(nn.Module):
    def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
        super().__init__()

        self.audio_in_dim = audio_in_dim
        self.audio_proj_dim = audio_proj_dim

    def split_audio_sequence(self, audio_proj_length, num_frames=81):
        """
        Map the audio feature sequence to corresponding latent frame slices.

        Args:
            audio_proj_length (int): The total length of the audio feature sequence
                (e.g., 173 in audio_proj[1, 173, 768]).
            num_frames (int): The number of video frames in the training data (default: 81).

        Returns:
            list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
                (within the audio feature sequence) corresponding to a latent frame.
        """
        # Average number of tokens per original video frame
        tokens_per_frame = audio_proj_length / num_frames

        # Each latent frame covers 4 video frames, and we want the center
        tokens_per_latent_frame = tokens_per_frame * 4
        half_tokens = int(tokens_per_latent_frame / 2)

        pos_indices = []
        for i in range(int((num_frames - 1) / 4) + 1):
            if i == 0:
                pos_indices.append(0)
            else:
                start_token = tokens_per_frame * ((i - 1) * 4 + 1)
                end_token = tokens_per_frame * (i * 4 + 1)
                center_token = int((start_token + end_token) / 2) - 1
                pos_indices.append(center_token)

        # Build index ranges centered around each position
        pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]

        # Adjust the first range to avoid negative start index
        pos_idx_ranges[0] = [
            -(half_tokens * 2 - pos_idx_ranges[1][0]),
            pos_idx_ranges[1][0],
        ]

        return pos_idx_ranges

    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
        """
        Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
        if the range exceeds the input boundaries.

        Args:
            input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
            pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
            expand_length (int): Number of tokens to expand on both sides of each subsequence.

        Returns:
            sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
                Each element is a padded subsequence.
            k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
                Useful for ignoring padding tokens in attention masks.
        """
        pos_idx_ranges = [
            [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
        ]
        sub_sequences = []
        seq_len = input_tensor.size(1)  # 173
        max_valid_idx = seq_len - 1  # 172
        k_lens_list = []
        for start, end in pos_idx_ranges:
            # Calculate the fill amount
            pad_front = max(-start, 0)
            pad_back = max(end - max_valid_idx, 0)

            # Calculate the start and end indices of the valid part
            valid_start = max(start, 0)
            valid_end = min(end, max_valid_idx)

            # Extract the valid part
            if valid_start <= valid_end:
                valid_part = input_tensor[:, valid_start : valid_end + 1, :]
            else:
                valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))

            # In the sequence dimension (the 1st dimension) perform padding
            padded_subseq = F.pad(
                valid_part,
                (0, 0, 0, pad_back + pad_front, 0, 0),
                mode="constant",
                value=0,
            )
            k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)

            sub_sequences.append(padded_subseq)
        return torch.stack(sub_sequences, dim=1), torch.tensor(
            k_lens_list, dtype=torch.long
        )
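A quick numeric check of the two helpers above, using the example values from the docstrings (173 audio tokens for an 81-frame clip). This is a sketch to be run from the repository root so that `wan.modules.attention` is importable:

```python
import torch
from fantasytalking.model import FantasyTalkingAudioConditionModel

model = FantasyTalkingAudioConditionModel(None, audio_in_dim=768, audio_proj_dim=2048)

ranges = model.split_audio_sequence(173, num_frames=81)
print(len(ranges))            # 21 latent frames
print(ranges[0], ranges[1])   # [-7, 1] [1, 9], as in the docstring example

audio_proj_fea = torch.randn(1, 173, 2048)  # dummy projected audio features
sub_sequences, k_lens = model.split_tensor_with_padding(audio_proj_fea, ranges, expand_length=4)
print(sub_sequences.shape)    # [1, 21, 17, 2048] for this example (17 = padded window length)
print(k_lens)                 # true (unpadded) length per latent frame, e.g. 6, 17, ..., 17, 14
```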

52
fantasytalking/utils.py Normal file
View File

@ -0,0 +1,52 @@
# Copyright Alibaba Inc. All Rights Reserved.

import imageio
import librosa
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


def resize_image_by_longest_edge(image_path, target_size):
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    scale = target_size / max(width, height)
    new_size = (int(width * scale), int(height * scale))
    return image.resize(new_size, Image.LANCZOS)


def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    writer = imageio.get_writer(
        save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
    )
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()


def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
    sr = 16000
    audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # sample rate is 16 kHz

    start_time = 0
    # end_time = (0 + (num_frames - 1) * 1) / fps
    end_time = num_frames / fps

    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)

    try:
        audio_segment = audio_input[start_sample:end_sample]
    except Exception:
        audio_segment = audio_input

    input_values = audio_processor(
        audio_segment, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values.to("cuda")

    with torch.no_grad():
        fea = wav2vec(input_values).last_hidden_state

    return fea

View File

@ -16,7 +16,7 @@ gradio==5.23.0
numpy>=1.23.5,<2 numpy>=1.23.5,<2
einops einops
moviepy==1.0.3 moviepy==1.0.3
mmgp==3.4.1 mmgp==3.4.2
peft==0.14.0 peft==0.14.0
mutagen mutagen
pydantic==2.10.6 pydantic==2.10.6
@ -28,4 +28,5 @@ timm
segment-anything segment-anything
omegaconf omegaconf
hydra-core hydra-core
librosa
# rembg==2.0.65 # rembg==2.0.65

View File

@ -44,6 +44,8 @@ SUPPORTED_SIZES = {
VACE_SIZE_CONFIGS = { VACE_SIZE_CONFIGS = {
'480*832': (480, 832), '480*832': (480, 832),
'832*480': (832, 480), '832*480': (832, 480),
'720*1280': (720, 1280),
'1280*720': (1280, 720),
} }
VACE_MAX_AREA_CONFIGS = { VACE_MAX_AREA_CONFIGS = {

View File

@ -56,16 +56,18 @@ class DTT2V:
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device) device=self.device)
logging.info(f"Creating WanModel from {model_filename}") logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload from mmgp import offload
# model_filename = "model.safetensors" # model_filename = "model.safetensors"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json") # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
# offload.load_model_data(self.model, "recam.ckpt") # offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu() # self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) # dtype = torch.float16
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True) offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json") # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json") # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json")
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json") # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
@ -200,6 +202,9 @@ class DTT2V:
fps: int = 24, fps: int = 24,
VAE_tile_size = 0, VAE_tile_size = 0,
joint_pass = False, joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
callback = None, callback = None,
): ):
self._interrupt = False self._interrupt = False
@ -211,6 +216,7 @@ class DTT2V:
if ar_step == 0: if ar_step == 0:
causal_block_size = 1 causal_block_size = 1
causal_attention = False
i2v_extra_kwrags = {} i2v_extra_kwrags = {}
prefix_video = None prefix_video = None
@ -252,31 +258,33 @@ class DTT2V:
prefix_video = output_video.to(self.device) prefix_video = output_video.to(self.device)
else: else:
causal_block_size = 1 causal_block_size = 1
causal_attention = False
ar_step = 0 ar_step = 0
prefix_video = image prefix_video = image
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
if prefix_video.dtype == torch.uint8: if prefix_video.dtype == torch.uint8:
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
prefix_video = prefix_video.to(self.device) prefix_video = prefix_video.to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)]
predix_video_latent_length = prefix_video[0].shape[1] predix_video_latent_length = prefix_video.shape[1]
truncate_len = predix_video_latent_length % causal_block_size truncate_len = predix_video_latent_length % causal_block_size
if truncate_len != 0: if truncate_len != 0:
if truncate_len == predix_video_latent_length: if truncate_len == predix_video_latent_length:
causal_block_size = 1 causal_block_size = 1
causal_attention = False
ar_step = 0
else: else:
print("the length of prefix video is truncated for the casual block size alignment.") print("the length of prefix video is truncated for the casual block size alignment.")
predix_video_latent_length -= truncate_len predix_video_latent_length -= truncate_len
prefix_video[0] = prefix_video[0][:, : predix_video_latent_length] prefix_video = prefix_video[:, : predix_video_latent_length]
base_num_frames_iter = latent_length base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width] latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents( latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator latent_shape, dtype=torch.float32, device=self.device, generator=generator
) )
latents = [latents]
if prefix_video is not None: if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter, base_num_frames_iter,
init_timesteps, init_timesteps,
@ -298,6 +306,8 @@ class DTT2V:
if callback != None: if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if self.model.enable_teacache: if self.model.enable_teacache:
x_count = 2 if self.do_classifier_free_guidance else 1
self.model.previous_residual = [None] * x_count
time_steps_comb = [] time_steps_comb = []
self.model.num_steps = updated_num_steps self.model.num_steps = updated_num_steps
for i, timestep_i in enumerate(step_matrix): for i, timestep_i in enumerate(step_matrix):
@ -309,7 +319,7 @@ class DTT2V:
self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier) self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
del time_steps_comb del time_steps_comb
from mmgp import offload from mmgp import offload
freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False) freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
kwrags = { kwrags = {
"freqs" :freqs, "freqs" :freqs,
"fps" : fps_embeds, "fps" : fps_embeds,
@ -320,27 +330,27 @@ class DTT2V:
} }
kwrags.update(i2v_extra_kwrags) kwrags.update(i2v_extra_kwrags)
for i, timestep_i in enumerate(tqdm(step_matrix)): for i, timestep_i in enumerate(tqdm(step_matrix)):
kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
offload.set_step_no_for_lora(self.model, i) offload.set_step_no_for_lora(self.model, i)
update_mask_i = step_update_mask[i] update_mask_i = step_update_mask[i]
valid_interval_start, valid_interval_end = valid_interval[i] valid_interval_start, valid_interval_end = valid_interval[i]
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] latent_model_input[:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor) * (1.0 - noise_factor)
+ torch.randn_like( + torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] latent_model_input[:, valid_interval_start:predix_video_latent_length]
) )
* noise_factor * noise_factor
) )
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags.update({ kwrags.update({
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep, "t" : timestep,
"current_step" : i, "current_step" : i,
}) })
@ -349,6 +359,7 @@ class DTT2V:
if True: if True:
if not self.do_classifier_free_guidance: if not self.do_classifier_free_guidance:
noise_pred = self.model( noise_pred = self.model(
x=[latent_model_input],
context=[prompt_embeds], context=[prompt_embeds],
**kwrags, **kwrags,
)[0] )[0]
@ -358,6 +369,7 @@ class DTT2V:
else: else:
if joint_pass: if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model( noise_pred_cond, noise_pred_uncond = self.model(
x=[latent_model_input, latent_model_input],
context= [prompt_embeds, negative_prompt_embeds], context= [prompt_embeds, negative_prompt_embeds],
**kwrags, **kwrags,
) )
@ -365,12 +377,16 @@ class DTT2V:
return None return None
else: else:
noise_pred_cond = self.model( noise_pred_cond = self.model(
x=[latent_model_input],
x_id=0,
context=[prompt_embeds], context=[prompt_embeds],
**kwrags, **kwrags,
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
x=[latent_model_input],
x_id=1,
context=[negative_prompt_embeds], context=[negative_prompt_embeds],
**kwrags, **kwrags,
)[0] )[0]
@ -380,18 +396,18 @@ class DTT2V:
del noise_pred_cond, noise_pred_uncond del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end): for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item(): if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step( latents[:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start], noise_pred[:, idx - valid_interval_start],
timestep_i[idx], timestep_i[idx],
latents[0][:, idx], latents[:, idx],
return_dict=False, return_dict=False,
generator=generator, generator=generator,
)[0] )[0]
sample_schedulers_counter[idx] += 1 sample_schedulers_counter[idx] += 1
if callback is not None: if callback is not None:
callback(i, latents[0].squeeze(0), False) callback(i, latents.squeeze(0), False)
x0 = latents[0].unsqueeze(0) x0 = latents.unsqueeze(0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
return output_video return output_video
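As a side note on the `slg_start` / `slg_end` parameters added above: Skip Layer Guidance is only passed to the model for the middle fraction of the denoising steps. A small standalone sketch of the gating expression used in the loop (the step count and fractions below are illustrative):

```python
# Which steps receive slg_layers, per the gating added in this diff.
updated_num_steps = 30
slg_start, slg_end = 0.1, 0.9

active_steps = [
    i for i in range(updated_num_steps)
    if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps)
]
print(active_steps[0], active_steps[-1])  # 3 26
```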

View File

@ -8,7 +8,7 @@ import sys
import types import types
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
import json
import numpy as np import numpy as np
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
@ -84,13 +84,29 @@ class WanI2V:
config.clip_checkpoint), config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {model_filename}") logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # fantasy = torch.load("c:/temp/fantasy.ckpt")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) # proj_model = fantasy["proj_model"]
offload.change_dtype(self.model, dtype, True) # audio_processor = fantasy["audio_processor"]
# offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True) # offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
# for k,v in audio_processor.items():
# audio_processor[k] = v.to(torch.bfloat16)
# with open("fantasy_config.json", "r", encoding="utf-8") as reader:
# config_text = reader.read()
# config_json = json.loads(config_text)
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
# model_filename = [model_filename, "audio_processor_bf16.safetensors"]
# model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
# dtype = torch.float16
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
# offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
@ -102,6 +118,8 @@ class WanI2V:
input_prompt, input_prompt,
img, img,
img2 = None, img2 = None,
height =720,
width = 1280,
max_area=720 * 1280, max_area=720 * 1280,
frame_num=81, frame_num=81,
shift=5.0, shift=5.0,
@ -119,7 +137,11 @@ class WanI2V:
slg_end = 1.0, slg_end = 1.0,
cfg_star_switch = True, cfg_star_switch = True,
cfg_zero_step = 5, cfg_zero_step = 5,
add_frames_for_end_image = True add_frames_for_end_image = True,
audio_scale=None,
audio_cfg_scale=None,
audio_proj=None,
audio_context_lens=None,
): ):
r""" r"""
Generates video frames from input image and text prompt using diffusion process. Generates video frames from input image and text prompt using diffusion process.
@ -167,13 +189,21 @@ class WanI2V:
frame_num +=1 frame_num +=1
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
h, w = img.shape[1:] h, w = img.shape[1:]
aspect_ratio = h / w # aspect_ratio = h / w
scale1 = min(height / h, width / w)
scale2 = min(height / h, width / w)
scale = max(scale1, scale2)
new_height = int(h * scale)
new_width = int(w * scale)
lat_h = round( lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // new_height // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1]) self.patch_size[1] * self.patch_size[1])
lat_w = round( lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // new_width // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2]) self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1] h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2] w = lat_w * self.vae_stride[2]
@ -271,98 +301,101 @@ class WanI2V:
# sample videos # sample videos
latent = noise latent = noise
batch_size = latent.shape[0] batch_size = 1
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx) freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = { kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
'context': [context],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
'callback' : callback
}
arg_null = { if audio_proj != None:
'context': [context_null], kwargs.update({
'clip_fea': clip_context, "audio_proj": audio_proj.to(self.dtype),
'y': [y], "audio_context_lens": audio_context_lens,
'freqs' : freqs, })
'pipeline' : self,
'callback' : callback
}
arg_both= {
'context': [context, context_null],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
'pipeline' : self,
'callback' : callback
}
if self.model.enable_teacache: if self.model.enable_teacache:
self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
# self.model.to(self.device) # self.model.to(self.device)
if callback != None: if callback != None:
callback(-1, None, True) callback(-1, None, True)
latent = latent.to(self.device)
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i) offload.set_step_no_for_lora(self.model, i)
slg_layers_local = None kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): latent_model_input = latent
slg_layers_local = slg_layers
latent_model_input = [latent.to(self.device)]
timestep = [t] timestep = [t]
timestep = torch.stack(timestep).to(self.device) timestep = torch.stack(timestep).to(self.device)
kwargs.update({
't' :timestep,
'current_step' :i,
})
if joint_pass: if joint_pass:
if audio_proj == None:
noise_pred_cond, noise_pred_uncond = self.model( noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) [latent_model_input, latent_model_input],
context=[context, context_null],
**kwargs)
else:
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
[latent_model_input, latent_model_input, latent_model_input],
context=[context, context, context_null],
audio_scale = [audio_scale, None, None ],
**kwargs)
if self._interrupt: if self._interrupt:
return None return None
else: else:
noise_pred_cond = self.model( noise_pred_cond = self.model(
latent_model_input, [latent_model_input],
t=timestep, context=[context],
current_step=i, audio_scale = None if audio_scale == None else [audio_scale],
is_uncond=False, x_id=0,
**arg_c, **kwargs,
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
if audio_proj != None:
noise_pred_noaudio = self.model(
[latent_model_input],
x_id=1,
context=[context],
**kwargs,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
latent_model_input, [latent_model_input],
t=timestep, x_id=1 if audio_scale == None else 2,
current_step=i, context=[context_null],
is_uncond=True, **kwargs,
slg_layers=slg_layers_local,
**arg_null,
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
del latent_model_input del latent_model_input
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch: if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1) positive_flat = noise_pred_cond.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1) negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat) alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1) alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step): if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
else: else:
noise_pred_uncond *= alpha noise_pred_uncond *= alpha
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) if audio_scale == None:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_uncond else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
noise_pred_uncond, noise_pred_noaudio = None, None
temp_x0 = sample_scheduler.step( temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0), noise_pred.unsqueeze(0),
t, t,
@ -376,9 +409,6 @@ class WanI2V:
if callback is not None: if callback is not None:
callback(i, latent, False) callback(i, latent, False)
# x0 = [latent.to(self.device, dtype=self.dtype)]
x0 = [latent] x0 = [latent]
# x0 = [lat_y] # x0 = [lat_y]
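For clarity, the guidance combination introduced above when an audio track is provided blends three predictions: unconditional, text-only (no audio), and text + audio. A standalone restatement with illustrative tensors (names mirror the variables in the diff; shapes and scales are placeholders):

```python
import torch

shape = (16, 21, 90, 160)                 # illustrative latent shape
noise_pred_cond = torch.randn(shape)      # text + audio conditioned
noise_pred_noaudio = torch.randn(shape)   # text conditioned, audio disabled
noise_pred_uncond = torch.randn(shape)    # unconditional
guide_scale, audio_cfg_scale = 5.0, 5.0   # illustrative CFG scales

noise_pred = (
    noise_pred_uncond
    + guide_scale * (noise_pred_noaudio - noise_pred_uncond)
    + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
)
```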

View File

@ -57,15 +57,15 @@ def sageattn_wrapper(
): ):
q,k, v = qkv_list q,k, v = qkv_list
padding_length = q.shape[0] -attention_length padding_length = q.shape[0] -attention_length
q = q[:attention_length, :, : ].unsqueeze(0) q = q[:attention_length, :, : ]
k = k[:attention_length, :, : ].unsqueeze(0) k = k[:attention_length, :, : ]
v = v[:attention_length, :, : ].unsqueeze(0) v = v[:attention_length, :, : ]
if True: if True:
qkv_list = [q,k,v] qkv_list = [q,k,v]
del q, k ,v del q, k ,v
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0) o = alt_sageattn(qkv_list, tensor_layout="NHD")
else: else:
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0) o = sageattn(q, k, v, tensor_layout="NHD")
del q, k ,v del q, k ,v
qkv_list.clear() qkv_list.clear()
@ -107,14 +107,14 @@ def sdpa_wrapper(
attention_length attention_length
): ):
q,k, v = qkv_list q,k, v = qkv_list
padding_length = q.shape[0] -attention_length padding_length = q.shape[1] -attention_length
q = q[:attention_length, :].transpose(0,1).unsqueeze(0) q = q[:attention_length, :].transpose(1,2)
k = k[:attention_length, :].transpose(0,1).unsqueeze(0) k = k[:attention_length, :].transpose(1,2)
v = v[:attention_length, :].transpose(0,1).unsqueeze(0) v = v[:attention_length, :].transpose(1,2)
o = F.scaled_dot_product_attention( o = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, is_causal=False q, k, v, attn_mask=None, is_causal=False
).squeeze(0).transpose(0,1) ).transpose(1,2)
del q, k ,v del q, k ,v
qkv_list.clear() qkv_list.clear()
@ -159,36 +159,72 @@ def pay_attention(
deterministic=False, deterministic=False,
version=None, version=None,
force_attention= None, force_attention= None,
cross_attn= False cross_attn= False,
k_lens = None
): ):
attn = offload.shared_state["_attention"] if force_attention== None else force_attention attn = offload.shared_state["_attention"] if force_attention== None else force_attention
q,k,v = qkv_list q,k,v = qkv_list
qkv_list.clear() qkv_list.clear()
# params # params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
assert b==1
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
q = q.to(v.dtype) q = q.to(v.dtype)
k = k.to(v.dtype) k = k.to(v.dtype)
if b > 0 and k_lens != None and attn in ("sage2", "sdpa"):
# if q_scale is not None: # Poor's man var len attention
# q = q * q_scale chunk_sizes = []
k_sizes = []
current_size = k_lens[0]
current_count= 1
for k_len in k_lens[1:]:
if k_len == current_size:
current_count += 1
else:
chunk_sizes.append(current_count)
k_sizes.append(current_size)
current_count = 1
current_size = k_len
chunk_sizes.append(current_count)
k_sizes.append(k_len)
if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]:
q_chunks =torch.split(q, chunk_sizes)
k_chunks =torch.split(k, chunk_sizes)
v_chunks =torch.split(v, chunk_sizes)
q, k, v = None, None, None
k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)]
v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)]
o = []
for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks):
qkv_list = [sub_q, sub_k, sub_v]
sub_q, sub_k, sub_v = None, None, None
o.append( pay_attention(qkv_list) )
q_chunks, k_chunks, v_chunks = None, None, None
o = torch.cat(o, dim = 0)
return o
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn( warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.' 'Flash attention 3 is not available, use flash attention 2 instead.'
) )
if attn=="sage" or attn=="flash": if attn=="sage" or attn=="flash":
if b != 1 :
if k_lens == None:
k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
k = torch.cat([u[:v] for u, v in zip(k, k_lens)])
v = torch.cat([u[:v] for u, v in zip(v, k_lens)])
q = q.reshape(-1, *q.shape[-2:])
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
else:
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda") cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda") cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
# apply attention # apply attention
if attn=="sage": if attn=="sage":
@ -207,7 +243,7 @@ def pay_attention(
qkv_list = [q,k,v] qkv_list = [q,k,v]
del q,k,v del q,k,v
x = sageattn_wrapper(qkv_list, lq).unsqueeze(0) x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
# else: # else:
# layer = offload.shared_state["layer"] # layer = offload.shared_state["layer"]
# embed_sizes = offload.shared_state["embed_sizes"] # embed_sizes = offload.shared_state["embed_sizes"]
@ -267,8 +303,8 @@ def pay_attention(
elif attn=="sdpa": elif attn=="sdpa":
qkv_list = [q, k, v] qkv_list = [q, k, v]
del q, k , v del q ,k ,v
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0) x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0)
elif attn=="flash" and version == 3: elif attn=="flash" and version == 3:
# Note: dropout_p, window_size are not supported in FA3 now. # Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func( x = flash_attn_interface.flash_attn_varlen_func(
@ -302,59 +338,11 @@ def pay_attention(
# output # output
elif attn=="xformers": elif attn=="xformers":
x = memory_efficient_attention( from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
q.unsqueeze(0), if b != 1 and k_lens != None:
k.unsqueeze(0), attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) )
v.unsqueeze(0), x = memory_efficient_attention(q, k, v, attn_bias= attn_mask )
) #.unsqueeze(0) else:
x = memory_efficient_attention(q, k, v )
return x.type(out_dtype) return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return pay_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
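To make the new `k_lens` path in `pay_attention` easier to follow: consecutive batch rows that share the same key length are grouped, each group's keys/values are truncated to that length, and attention is run once per group. The grouping itself is plain bookkeeping, sketched here with example lengths (no attention call involved):

```python
# Run-length grouping of k_lens, as done in pay_attention above.
k_lens = [6] + [17] * 19 + [14]   # e.g. unpadded audio token counts per latent frame

chunk_sizes, k_sizes = [], []
current_size, current_count = k_lens[0], 1
for k_len in k_lens[1:]:
    if k_len == current_size:
        current_count += 1
    else:
        chunk_sizes.append(current_count)
        k_sizes.append(current_size)
        current_count, current_size = 1, k_len
chunk_sizes.append(current_count)
k_sizes.append(k_len)

print(chunk_sizes)  # [1, 19, 1]  -> three grouped attention calls instead of 21
print(k_sizes)      # [6, 17, 14] -> key length used for each group
```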

View File

@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module):
del q,k del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False) q,k = apply_rotary_emb(qklist, freqs, head_first=False)
if block_mask == None:
qkv_list = [q,k,v] qkv_list = [q,k,v]
del q,k,v del q,k,v
if block_mask == None:
x = pay_attention( x = pay_attention(
qkv_list, qkv_list,
window_size=self.window_size) window_size=self.window_size)
@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module):
.transpose(1, 2) .transpose(1, 2)
.contiguous() .contiguous()
) )
del q,k,v
# if not self._flag_ar_attention: # if not self._flag_ar_attention:
# q = rope_apply(q, grid_sizes, freqs) # q = rope_apply(q, grid_sizes, freqs)
@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention): class WanT2VCrossAttention(WanSelfAttention):
def forward(self, xlist, context): def forward(self, xlist, context, grid_sizes, *args, **kwargs):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d)
# compute attention # compute attention
v = v.contiguous().clone()
qvl_list=[q, k, v] qvl_list=[q, k, v]
del q, k, v del q, k, v
x = pay_attention(qvl_list, cross_attn= True) x = pay_attention(qvl_list, cross_attn= True)
@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, ))) # self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, xlist, context): def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention):
del x del x
self.norm_q(q) self.norm_q(q)
q= q.view(b, -1, n, d) q= q.view(b, -1, n, d)
if audio_scale != None:
audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
k = self.k(context) k = self.k(context)
self.norm_k(k) self.norm_k(k)
k = k.view(b, -1, n, d) k = k.view(b, -1, n, d)
@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention):
img_x = img_x.flatten(2) img_x = img_x.flatten(2)
x += img_x x += img_x
del img_x del img_x
if audio_scale != None:
x.add_(audio_x, alpha= audio_scale)
x = self.o(x) x = self.o(x)
return x return x
@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module):
hints= None, hints= None,
context_scale=1.0, context_scale=1.0,
cam_emb= None, cam_emb= None,
block_mask = None block_mask = None,
audio_proj= None,
audio_context_lens= None,
audio_scale=None,
): ):
r""" r"""
Args: Args:
@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module):
if cam_emb != None: if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb) cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1) cam_emb = cam_emb.repeat(1, 2, 1)
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1) cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1)
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d') cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
x_mod += cam_emb x_mod += cam_emb
@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module):
y = y.to(attention_dtype) y = y.to(attention_dtype)
ylist= [y] ylist= [y]
del y del y
x += self.cross_attn(ylist, context).to(dtype) x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
y = self.norm2(x) y = self.norm2(x)
@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin):
eps=1e-6, eps=1e-6,
recammaster = False, recammaster = False,
inject_sample_info = False, inject_sample_info = False,
fantasytalking_dim = 0,
): ):
r""" r"""
Initialize the diffusion model backbone. Initialize the diffusion model backbone.
@ -742,42 +752,47 @@ class WanModel(ModelMixin, ConfigMixin):
block.projector.weight = nn.Parameter(torch.eye(dim)) block.projector.weight = nn.Parameter(torch.eye(dim))
block.projector.bias = nn.Parameter(torch.zeros(dim)) block.projector.bias = nn.Parameter(torch.zeros(dim))
if fantasytalking_dim > 0:
from fantasytalking.model import WanCrossAttentionProcessor
for block in self.blocks:
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
def lock_layers_dtypes(self, dtype = torch.float32, force = False):
count = 0 def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2], layer_list = [self.head, self.head.head, self.patch_embedding]
target_dype= dtype
layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2],
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ] self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
for block in self.blocks:
layer_list2 += [block.norm3]
if hasattr(self, "fps_embedding"): if hasattr(self, "fps_embedding"):
layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]] layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
if hasattr(self, "vace_patch_embedding"): if hasattr(self, "vace_patch_embedding"):
layer_list += [self.vace_patch_embedding] layer_list2 += [self.vace_patch_embedding]
layer_list += [self.vace_blocks[0].before_proj] layer_list2 += [self.vace_blocks[0].before_proj]
for block in self.vace_blocks: for block in self.vace_blocks:
layer_list += [block.after_proj, block.norm3] layer_list2 += [block.after_proj, block.norm3]
target_dype2 = hybrid_dtype if hybrid_dtype != None else dtype
# cam master # cam master
if hasattr(self.blocks[0], "projector"): if hasattr(self.blocks[0], "projector"):
for block in self.blocks: for block in self.blocks:
layer_list += [block.projector] layer_list2 += [block.projector]
for block in self.blocks:
layer_list += [block.norm3]
for layer in layer_list:
if hasattr(layer, "weight"):
if layer.weight.dtype == dtype :
count += 1
elif force:
if hasattr(layer, "weight"):
layer.weight.data = layer.weight.data.to(dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(dtype)
count += 1
for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dype, target_dype2]):
for layer in current_layer_list:
layer._lock_dtype = dtype layer._lock_dtype = dtype
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
layer.weight.data = layer.weight.data.to(current_dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(current_dtype)
if count > 0:
self._lock_dtype = dtype self._lock_dtype = dtype
@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin):
t = torch.stack([t]) t = torch.stack([t])
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
e_list.append(time_emb) e_list.append(time_emb)
best_deltas = None
best_threshold = 0.01 best_threshold = 0.01
best_diff = 1000 best_diff = 1000
best_signed_diff = 1000 best_signed_diff = 1000
@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin):
accumulated_rel_l1_distance =0 accumulated_rel_l1_distance =0
nb_steps = 0 nb_steps = 0
diff = 1000 diff = 1000
deltas = []
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
skip = False skip = False
if not (i<=start_step or i== len(timesteps)): if not (i<=start_step or i== len(timesteps)-1):
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item())) delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
# deltas.append(delta)
accumulated_rel_l1_distance += delta
if accumulated_rel_l1_distance < threshold: if accumulated_rel_l1_distance < threshold:
skip = True skip = True
# deltas.append("SKIP")
else: else:
accumulated_rel_l1_distance = 0 accumulated_rel_l1_distance = 0
if not skip: if not skip:
@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin):
diff = abs(signed_diff) diff = abs(signed_diff)
if diff < best_diff: if diff < best_diff:
best_threshold = threshold best_threshold = threshold
best_deltas = deltas
best_diff = diff best_diff = diff
best_signed_diff = signed_diff best_signed_diff = signed_diff
elif diff > best_diff: elif diff > best_diff:
@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin):
threshold += 0.01 threshold += 0.01
self.rel_l1_thresh = best_threshold self.rel_l1_thresh = best_threshold
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}") print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
# print(f"deltas:{best_deltas}")
return best_threshold return best_threshold
@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin):
freqs = None, freqs = None,
pipeline = None, pipeline = None,
current_step = 0, current_step = 0,
is_uncond=False, x_id= 0,
max_steps = 0, max_steps = 0,
slg_layers=None, slg_layers=None,
callback = None, callback = None,
@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin):
fps = None, fps = None,
causal_block_size = 1, causal_block_size = 1,
causal_attention = False, causal_attention = False,
x_neg = None audio_proj=None,
audio_context_lens=None,
audio_scale=None,
): ):
# dtype = self.blocks[0].self_attn.q.weight.dtype # patch_dtype = self.patch_embedding.weight.dtype
dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype
if self.model_type == 'i2v': if self.model_type == 'i2v':
assert clip_fea is not None and y is not None assert clip_fea is not None and y is not None
@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin):
if torch.is_tensor(freqs) and freqs.device != device: if torch.is_tensor(freqs) and freqs.device != device:
freqs = freqs.to(device) freqs = freqs.to(device)
x_list = x
joint_pass = len(x_list) > 1
is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ]
last_x_idx = 0
for i, (is_source, x) in enumerate(zip(is_source_x, x_list)):
if is_source:
x_list[i] = x_list[0].clone()
last_x_idx = i
else:
# image source
if y is not None: if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] x = torch.cat([x, y], dim=0)
# embeddings # embeddings
x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x] x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
if x_neg !=None: grid_sizes = x.shape[2:]
x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg] x = x.flatten(2).transpose(1, 2)
x_list[i] = x
x, y = None, None
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0] block_mask = None
if causal_attention : #causal_block_size > 0: if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
frame_num = embed_sizes[0] frame_num = grid_sizes[0]
height = embed_sizes[1] height = grid_sizes[1]
width = embed_sizes[2] width = grid_sizes[2]
block_num = frame_num // causal_block_size block_num = frame_num // causal_block_size
range_tensor = torch.arange(block_num).view(-1, 1) range_tensor = torch.arange(block_num).view(-1, 1)
range_tensor = range_tensor.repeat(1, causal_block_size).flatten() range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin):
block_mask = causal_mask.unsqueeze(0).unsqueeze(0) block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
del causal_mask del causal_mask
offload.shared_state["embed_sizes"] = embed_sizes offload.shared_state["embed_sizes"] = grid_sizes
offload.shared_state["step_no"] = current_step offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps offload.shared_state["max_steps"] = max_steps
x = [u.flatten(2).transpose(1, 2) for u in x] _flag_df = t.dim() == 2
x = x[0]
if x_neg !=None:
x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
x_neg = x_neg[0]
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding( e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype) sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype)
) # b, dim ) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info: if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device) fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).to(dtype) # float() fps_emb = self.fps_embedding(fps).to(e.dtype)
if _flag_df: if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else: else:
@ -914,29 +941,27 @@ class WanModel(ModelMixin, ConfigMixin):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ] context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
joint_pass = len(context) > 0
x_list = [x]
if joint_pass:
if x_neg == None:
x_list += [x.clone() for i in range(len(context) - 1) ]
else:
x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
is_uncond = False
del x
context_list = context context_list = context
if audio_scale != None:
audio_scale_list = audio_scale
else:
audio_scale_list = [None] * len(x_list)
# arguments # arguments
kwargs = dict( kwargs = dict(
grid_sizes=grid_sizes, grid_sizes=grid_sizes,
freqs=freqs, freqs=freqs,
cam_emb = cam_emb cam_emb = cam_emb,
block_mask = block_mask,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
) )
if vace_context == None: if vace_context == None:
hints_list = [None ] *len(x_list) hints_list = [None ] *len(x_list)
else: else:
# embeddings # Vace embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c] c = [u.flatten(2).transpose(1, 2) for u in c]
c = c[0] c = c[0]
@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin):
should_calc = True should_calc = True
if self.enable_teacache: if self.enable_teacache:
if is_uncond: if x_id != 0:
should_calc = self.should_calc should_calc = self.should_calc
else: else:
if current_step <= self.teacache_start_step or current_step == self.num_steps-1: if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin):
self.accumulated_rel_l1_distance = 0 self.accumulated_rel_l1_distance = 0
else: else:
rescale_func = np.poly1d(self.coefficients) rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())) delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
self.accumulated_rel_l1_distance += delta
if self.accumulated_rel_l1_distance < self.rel_l1_thresh: if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False should_calc = False
self.teacache_skipped_steps += 1 self.teacache_skipped_steps += 1
# print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" ) # print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" )
else: else:
should_calc = True should_calc = True
self.accumulated_rel_l1_distance = 0 self.accumulated_rel_l1_distance = 0
@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin):
self.should_calc = should_calc self.should_calc = should_calc
if not should_calc: if not should_calc:
if joint_pass:
for i, x in enumerate(x_list): for i, x in enumerate(x_list):
x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond x += self.previous_residual[i]
else:
x = x_list[0]
x += self.previous_residual[x_id]
x = None
else: else:
if self.enable_teacache: if self.enable_teacache:
if joint_pass or is_uncond: if joint_pass:
self.previous_residual_uncond = None self.previous_residual = [ None ] * len(self.previous_residual)
if joint_pass or not is_uncond: else:
self.previous_residual_cond = None self.previous_residual[x_id] = None
ori_hidden_states = x_list[0].clone() ori_hidden_states = [ None ] * len(x_list)
ori_hidden_states[0] = x_list[0].clone()
for i in range(1, len(x_list)):
ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone()
for block_idx, block in enumerate(self.blocks): for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx offload.shared_state["layer"] = block_idx
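A minimal illustrative sketch (not from this diff) of the TeaCache bookkeeping these hunks refactor, where a list of per-stream residuals indexed like x_id above replaces the old cond/uncond pair; coefficient and threshold values below are placeholders, not the model's real configuration.

import numpy as np
import torch

class TeaCacheState:
    def __init__(self, num_streams, coefficients, rel_l1_thresh, start_step, num_steps):
        self.coefficients = coefficients          # polynomial rescale of the relative L1 distance
        self.rel_l1_thresh = rel_l1_thresh
        self.start_step = start_step
        self.num_steps = num_steps
        self.accumulated = 0.0
        self.previous_modulated_input = None
        self.previous_residual = [None] * num_streams   # one cached residual per stream

    def should_run_blocks(self, step, modulated_input):
        # Always run the transformer blocks on the warm-up steps and on the final step.
        if step <= self.start_step or step == self.num_steps - 1:
            self.accumulated = 0.0
            run = True
        else:
            rescale = np.poly1d(self.coefficients)
            rel_l1 = ((modulated_input - self.previous_modulated_input).abs().mean()
                      / self.previous_modulated_input.abs().mean()).item()
            self.accumulated += abs(rescale(rel_l1))
            run = self.accumulated >= self.rel_l1_thresh
            if run:
                self.accumulated = 0.0
        self.previous_modulated_input = modulated_input
        return run

    def skip(self, x_list):
        # Reuse each stream's cached residual instead of running the blocks.
        for i, x in enumerate(x_list):
            x += self.previous_residual[i]

    def cache(self, x_list, ori_list):
        # Store the residual produced by a full pass, one entry per stream.
        for i, (x, ori) in enumerate(zip(x_list, ori_list)):
            self.previous_residual[i] = x - ori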
@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin):
if pipeline._interrupt: if pipeline._interrupt:
return [None] * len(x_list) return [None] * len(x_list)
if slg_layers is not None and block_idx in slg_layers: if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
if is_uncond and not joint_pass: if not joint_pass:
continue continue
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs) x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
else: else:
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)): for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)):
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs) x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
del x del x
del context, hints del context, hints
if self.enable_teacache: if self.enable_teacache:
if joint_pass: if joint_pass:
self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states) for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
self.previous_residual_uncond = ori_hidden_states if i == 0 or is_source and i != last_x_idx :
torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond) self.previous_residual[i] = torch.sub(x, ori)
else: else:
residual = ori_hidden_states # just to have a readable code self.previous_residual[i] = ori
torch.sub(x_list[0], ori_hidden_states, out=residual) torch.sub(x, ori, out=self.previous_residual[i])
if i==1 or is_uncond: ori_hidden_states[i] = None
self.previous_residual_uncond = residual x , ori = None, None
else: else:
self.previous_residual_cond = residual residual = ori_hidden_states[0] # just to have a readable code
torch.sub(x_list[0], ori_hidden_states[0], out=residual)
self.previous_residual[x_id] = residual
residual, ori_hidden_states = None, None residual, ori_hidden_states = None, None
for i, x in enumerate(x_list): for i, x in enumerate(x_list):
@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin):
c = self.out_dim c = self.out_dim
out = [] out = []
for u, v in zip(x, grid_sizes): for u in x:
u = u[:math.prod(v)].view(*v, *self.patch_size, c) u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u) u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
out.append(u) out.append(u)
return out return out
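A small standalone sketch (not from this diff) of the unpatchify step above, using dummy shapes to show how the flat token sequence is folded back into a (C, F, H, W) latent now that a single grid_sizes is shared by every sample; names and shapes are illustrative.

import math
import torch

def unpatchify(tokens, grid_sizes, patch_size, out_dim):
    # tokens: (seq_len, prod(patch_size) * out_dim); grid_sizes: (f, h, w) patches per axis.
    f, h, w = grid_sizes
    u = tokens[:math.prod(grid_sizes)].view(f, h, w, *patch_size, out_dim)
    u = torch.einsum('fhwpqrc->cfphqwr', u)
    return u.reshape(out_dim, *[g * p for g, p in zip(grid_sizes, patch_size)])

# Example: 21 latent frames of 30x52 patches, patch size (1, 2, 2), 16 output channels.
grid_sizes, patch_size, c = (21, 30, 52), (1, 2, 2), 16
tokens = torch.randn(math.prod(grid_sizes), math.prod(patch_size) * c)
print(unpatchify(tokens, grid_sizes, patch_size, c).shape)  # torch.Size([16, 21, 60, 104])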

View File

@ -140,7 +140,7 @@ def sageattn(
elif arch == "sm90": elif arch == "sm90":
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32") return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
elif arch == "sm120": elif arch == "sm120":
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
else: else:
raise ValueError(f"Unsupported CUDA architecture: {arch}") raise ValueError(f"Unsupported CUDA architecture: {arch}")
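A hedged sketch (not from this diff) of how the arch string used by this dispatch is typically derived from the CUDA compute capability; the mapping below is an assumption about the surrounding wrapper, not code taken from it.

import torch

def get_cuda_arch(device=None):
    # e.g. (9, 0) -> "sm90" (Hopper), (12, 0) -> "sm120" (Blackwell consumer parts).
    major, minor = torch.cuda.get_device_capability(device)
    return f"sm{major}{minor}"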

View File

@ -78,15 +78,16 @@ class WanT2V:
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device) device=self.device)
logging.info(f"Creating WanModel from {model_filename}") logging.info(f"Creating WanModel from {model_filename[-1]}")
from mmgp import offload from mmgp import offload
# model_filename # model_filename
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json") self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
# offload.load_model_data(self.model, "e:/vace.safetensors") # offload.load_model_data(self.model, "e:/vace.safetensors")
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
# self.model.to(torch.bfloat16) # self.model.to(torch.bfloat16)
# self.model.cpu() # self.model.cpu()
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True) self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True) offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json") # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
# offload.save_model(self.model, "phantom_1.3B.safetensors") # offload.save_model(self.model, "phantom_1.3B.safetensors")
@ -95,7 +96,7 @@ class WanT2V:
self.sample_neg_prompt = config.sample_neg_prompt self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename: if "Vace" in model_filename[-1]:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832, min_area=480*832,
max_area=480*832, max_area=480*832,
@ -107,7 +108,7 @@ class WanT2V:
self.adapt_vace_model() self.adapt_vace_model()
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0):
if ref_images is None: if ref_images is None:
ref_images = [None] * len(frames) ref_images = [None] * len(frames)
else: else:
@ -119,6 +120,11 @@ class WanT2V:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size) inactive = self.vae.encode(inactive, tile_size = tile_size)
# inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive]
# if overlapped_latents > 0:
# for t in inactive:
# t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor
reactive = self.vae.encode(reactive, tile_size = tile_size) reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
@ -288,6 +294,9 @@ class WanT2V:
slg_end = 1.0, slg_end = 1.0,
cfg_star_switch = True, cfg_star_switch = True,
cfg_zero_step = 5, cfg_zero_step = 5,
overlapped_latents = 0,
overlap_noise = 0,
vace = False
): ):
r""" r"""
Generates video frames from text prompt using diffusion process. Generates video frames from text prompt using diffusion process.
@ -343,20 +352,20 @@ class WanT2V:
size = (source_video.shape[2], source_video.shape[1]) size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device) source_video = source_video.to(dtype=self.dtype , device=self.device)
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device)
del source_video del source_video
# Process target camera (recammaster) # Process target camera (recammaster)
from wan.utils.cammmaster_tools import get_camera_embedding from wan.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera) cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
if input_frames != None: if vace :
# vace context encode # vace context encode
input_frames = [u.to(self.device) for u in input_frames] input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks] input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise )
m0 = self.vace_encode_masks(input_masks, input_ref_images) m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0) z = self.vace_latent(z0, m0)
@ -365,10 +374,10 @@ class WanT2V:
else: else:
if input_ref_images != None: # Phantom Ref images if input_ref_images != None: # Phantom Ref images
phantom = True phantom = True
input_ref_images = [self.get_vae_latents(input_ref_images, self.device)] input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = [torch.zeros_like(input_ref_images[0])] input_ref_images_neg = torch.zeros_like(input_ref_images)
F = frame_num F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0), target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
size[1] // self.vae_stride[1], size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2]) size[0] // self.vae_stride[2])
@ -405,37 +414,48 @@ class WanT2V:
raise NotImplementedError("Unsupported solver.") raise NotImplementedError("Unsupported solver.")
# sample videos # sample videos
latents = noise latents = noise[0]
del noise del noise
batch_size =len(latents) batch_size = 1
if target_camera != None: if target_camera != None:
shape = list(latents[0].shape[1:]) shape = list(latents.shape[1:])
shape[0] *= 2 shape[0] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else: else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback} kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None: if target_camera != None:
kwargs.update({'cam_emb': cam_emb}) kwargs.update({'cam_emb': cam_emb})
if input_frames != None: if vace:
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale}) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if overlapped_latents > 0:
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
if self.model.enable_teacache: if self.model.enable_teacache:
x_count = 3 if phantom else 2
self.model.previous_residual = [None] * x_count
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None: if callback != None:
callback(-1, None, True) callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
if vace and overlapped_latents > 0 :
# noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
noise_factor = overlap_noise / 1000 # * (999-t) / 999
# noise_factor = overlap_noise / 1000 # * t / 999
for zz, zz_r in zip(z, z_reactive):
zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor
if target_camera != None: if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] latent_model_input = torch.cat([latents, source_latents], dim=1)
else: else:
latent_model_input = latents latent_model_input = latents
slg_layers_local = None kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
slg_layers_local = slg_layers
timestep = [t] timestep = [t]
offload.set_step_no_for_lora(self.model, i) offload.set_step_no_for_lora(self.model, i)
timestep = torch.stack(timestep) timestep = torch.stack(timestep)
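A minimal sketch (not from this diff) of the overlap re-noising applied above at every denoising step: the cached slice of the Vace context latents (z_reactive in the commit) is blended back in with a fixed amount of Gaussian noise, with overlap_noise expressed in thousandths.

import torch

def renoise_overlap(z, z_reactive, ref_images_count, overlapped_latents, overlap_noise):
    # z: list of Vace context latents; the commit caches channels 0:16 of the
    # overlapped latent frames (after the reference images) as z_reactive.
    noise_factor = overlap_noise / 1000
    sl = slice(ref_images_count, ref_images_count + overlapped_latents)
    for zz, zz_r in zip(z, z_reactive):
        zz[0:16, sl] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r) * noise_factor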
@ -444,38 +464,38 @@ class WanT2V:
if joint_pass: if joint_pass:
if phantom: if phantom:
pos_it, pos_i, neg = self.model( pos_it, pos_i, neg = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
context = [context, context_null, context_null], **kwargs) context = [context, context_null, context_null], **kwargs)
else: else:
noise_pred_cond, noise_pred_uncond = self.model( noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs) [latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
if self._interrupt: if self._interrupt:
return None return None
else: else:
if phantom: if phantom:
pos_it = self.model( pos_it = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
pos_i = self.model( pos_i = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
neg = self.model( neg = self.model(
[torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
)[0] )[0]
if self._interrupt: if self._interrupt:
return None return None
else: else:
noise_pred_cond = self.model( noise_pred_cond = self.model(
latent_model_input, is_uncond = False, context = [context], **kwargs)[0] [latent_model_input], x_id = 0, context = [context], **kwargs)[0]
if self._interrupt: if self._interrupt:
return None return None
noise_pred_uncond = self.model( noise_pred_uncond = self.model(
latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0] [latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
if self._interrupt: if self._interrupt:
return None return None
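A schematic sketch (not from this diff) of how the joint pass bundles the conditional and unconditional streams into one model call, with x_id identifying the stream when they are run separately (so the right TeaCache residual is reused); the final combination shown is plain classifier-free guidance, the repo's cfg_star / cfg_zero variants are omitted.

def cfg_denoise(model, latent, context, context_null, guide_scale, joint_pass, **kwargs):
    if joint_pass:
        # One forward over both streams: one prediction per supplied context.
        cond, uncond = model([latent, latent], context=[context, context_null], **kwargs)
    else:
        # Two forwards; x_id tells the model which cached residual belongs to which stream.
        cond = model([latent], x_id=0, context=[context], **kwargs)[0]
        uncond = model([latent], x_id=1, context=[context_null], **kwargs)[0]
    return uncond + guide_scale * (cond - uncond)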
@ -505,21 +525,21 @@ class WanT2V:
temp_x0 = sample_scheduler.step( temp_x0 = sample_scheduler.step(
noise_pred[:, :target_shape[1]].unsqueeze(0), noise_pred[:, :target_shape[1]].unsqueeze(0),
t, t,
latents[0].unsqueeze(0), latents.unsqueeze(0),
return_dict=False, return_dict=False,
generator=seed_g)[0] generator=seed_g)[0]
latents = [temp_x0.squeeze(0)] latents = temp_x0.squeeze(0)
del temp_x0 del temp_x0
if callback is not None: if callback is not None:
callback(i, latents[0], False) callback(i, latents, False)
x0 = latents x0 = [latents]
if input_frames == None: if input_frames == None:
if phantom: if phantom:
# phantom post processing # phantom post processing
x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0] x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
videos = self.vae.decode(x0, VAE_tile_size) videos = self.vae.decode(x0, VAE_tile_size)
else: else:
# vace post processing # vace post processing

325
wgp.py
View File

@ -40,7 +40,7 @@ global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip" AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10 PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.1" target_mmgp_version = "3.4.2"
from importlib.metadata import version from importlib.metadata import version
mmgp_version = version("mmgp") mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version: if mmgp_version != target_mmgp_version:
@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version:
lock = threading.Lock() lock = threading.Lock()
current_task_id = None current_task_id = None
task_id = 0 task_id = 0
# progress_tracker = {}
# tracker_lock = threading.Lock()
# def download_ffmpeg(): def download_ffmpeg():
# if os.name != 'nt': return if os.name != 'nt': return
# exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe'] exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
# if all(os.path.exists(e) for e in exes): return if all(os.path.exists(e) for e in exes): return
# api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest' api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
# r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'}) r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
# assets = r.json().get('assets', []) assets = r.json().get('assets', [])
# zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None) zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
# if not zip_asset: return if not zip_asset: return
# zip_url = zip_asset['browser_download_url'] zip_url = zip_asset['browser_download_url']
# zip_name = zip_asset['name'] zip_name = zip_asset['name']
# with requests.get(zip_url, stream=True) as resp: with requests.get(zip_url, stream=True) as resp:
# total = int(resp.headers.get('Content-Length', 0)) total = int(resp.headers.get('Content-Length', 0))
# with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar: with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
# for chunk in resp.iter_content(chunk_size=8192): for chunk in resp.iter_content(chunk_size=8192):
# f.write(chunk) f.write(chunk)
# pbar.update(len(chunk)) pbar.update(len(chunk))
# with zipfile.ZipFile(zip_name) as z: with zipfile.ZipFile(zip_name) as z:
# for f in z.namelist(): for f in z.namelist():
# if f.endswith(tuple(exes)) and '/bin/' in f: if f.endswith(tuple(exes)) and '/bin/' in f:
# z.extract(f) z.extract(f)
# os.rename(f, os.path.basename(f)) os.rename(f, os.path.basename(f))
# os.remove(zip_name) os.remove(zip_name)
def format_time(seconds): def format_time(seconds):
if seconds < 60: if seconds < 60:
@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice):
resolution = inputs["resolution"] resolution = inputs["resolution"]
width, height = resolution.split("x") width, height = resolution.split("x")
width, height = int(width), int(height) width, height = int(width), int(height)
if test_class_i2v(model_filename): # if test_class_i2v(model_filename):
if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
return # return
resolution = str(width) + "*" + str(height) # resolution = str(width) + "*" + str(height)
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']: # if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
gr.Info(f"Resolution {resolution} not supported by image 2 video") # gr.Info(f"Resolution {resolution} not supported by image 2 video")
return # return
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ): if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
@ -533,7 +531,7 @@ def save_queue_action(state):
task_id_s = task.get('id', f"task_{task_index}") task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
for key in image_keys: for key in image_keys:
images_pil = params_copy.get(key) images_pil = params_copy.get(key)
@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
params['state'] = state params['state'] = state
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
loaded_pil_images = {} loaded_pil_images = {}
loaded_video_paths = {} loaded_video_paths = {}
@ -925,7 +923,7 @@ def autosave_queue():
task_id_s = task.get('id', f"task_{task_index}") task_id_s = task.get('id', f"task_{task_index}")
image_keys = ["image_start", "image_end", "image_refs"] image_keys = ["image_start", "image_end", "image_refs"]
video_keys = ["video_guide", "video_mask", "video_source"] video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
for key in image_keys: for key in image_keys:
images_pil = params_copy.get(key) images_pil = params_copy.get(key)
@ -1418,32 +1416,35 @@ else:
text = reader.read() text = reader.read()
server_config = json.loads(text) server_config = json.loads(text)
# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ): # Deprecated models
# if Path(src_path).is_file(): for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors",
# shutil.move(src_path, tgt_path) ) "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors",
# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]: "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
# if Path(path).is_file(): ]:
# os.remove(path) if Path(os.path.join("ckpts" , path)).is_file():
os.remove( os.path.join("ckpts" , path))
path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
os.remove(path)
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"] "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
"ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"] "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
"ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v transformer_choices = transformer_choices_t2v + transformer_choices_i2v
def get_dependent_models(model_filename, quantization ):
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"] if "fantasy" in model_filename:
return [get_model_filename("i2v_720p", quantization)]
else:
return []
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
"phantom_1.3B" : "phantom_1.3B", } "phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" }
def get_model_type(model_filename): def get_model_type(model_filename):
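A hedged sketch (not from this diff) of the lookup the signatures table above drives: get_model_type's body is only partly visible here, so the loop is an assumption consistent with the exception raised in the next hunk.

def get_model_type_sketch(model_filename, model_types, model_signatures):
    # Return the first known type whose signature substring appears in the checkpoint name.
    for model_type in model_types:
        if model_signatures[model_type] in model_filename:
            return model_type
    raise Exception("Unknown model:" + model_filename)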
@ -1453,7 +1454,7 @@ def get_model_type(model_filename):
raise Exception("Unknown model:" + model_filename) raise Exception("Unknown model:" + model_filename)
def test_class_i2v(model_filename): def test_class_i2v(model_filename):
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename
def get_model_name(model_filename, description_container = [""]): def get_model_name(model_filename, description_container = [""]):
if "Fun" in model_filename: if "Fun" in model_filename:
@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]):
model_name = "Wan2.1 Phantom" model_name = "Wan2.1 Phantom"
model_name += " 14B" if "14B" in model_filename else " 1.3B" model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p." description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
elif "fantasy" in model_filename:
model_name = "Wan2.1 Fantasy Speaking 720p"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
else: else:
model_name = "Wan2.1 text2video" model_name = "Wan2.1 text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B" model_name += " 14B" if "14B" in model_filename else " 1.3B"
@ -1536,6 +1541,7 @@ def get_default_settings(filename):
"repeat_generation": 1, "repeat_generation": 1,
"multi_images_gen_type": 0, "multi_images_gen_type": 0,
"guidance_scale": 5.0, "guidance_scale": 5.0,
"audio_guidance_scale": 5.0,
"flow_shift": get_default_flow(filename, i2v), "flow_shift": get_default_flow(filename, i2v),
"negative_prompt": "", "negative_prompt": "",
"activated_loras": [], "activated_loras": [],
@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1" repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ] sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ] fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/" targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ): for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0: if len(files)==0:
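A hedged sketch (not from this diff) of the per-file download this hunk feeds, fetching the new wav2vec assets one by one; the loop body is not shown in the diff, so the hf_hub_download arguments here are an assumption.

from huggingface_hub import hf_hub_download

def fetch_assets(repo_id, source_folder, files, target_root="ckpts/"):
    for name in files:
        hf_hub_download(repo_id=repo_id, filename=name,
                        subfolder=source_folder or None, local_dir=target_root)

# e.g. fetch_assets("DeepBeepMeep/Wan2.1", "wav2vec",
#                   ["config.json", "model.safetensors", "preprocessor_config.json"])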
@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
cfg = WAN_CONFIGS['t2v-14B'] cfg = WAN_CONFIGS['t2v-14B']
filename = model_filename[-1]
# cfg = WAN_CONFIGS['t2v-1.3B'] # cfg = WAN_CONFIGS['t2v-1.3B']
print(f"Loading '{model_filename}' model...") print(f"Loading '{filename}' model...")
if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
model_factory = wan.DTT2V model_factory = wan.DTT2V
else: else:
model_factory = wan.WanT2V model_factory = wan.WanT2V
@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
return wan_model, pipe return wan_model, pipe
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False): def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
print(f"Loading '{model_filename}' model...") filename = model_filename[-1]
print(f"Loading '{filename}' model...")
cfg = WAN_CONFIGS['i2v-14B'] cfg = WAN_CONFIGS['i2v-14B']
wan_model = wan.WanI2V( wan_model = wan.WanI2V(
@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t
def load_models(model_filename): def load_models(model_filename):
global transformer_filename global transformer_filename
transformer_filename = model_filename
perc_reserved_mem_max = args.perc_reserved_mem_max perc_reserved_mem_max = args.perc_reserved_mem_max
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
@ -1892,18 +1900,26 @@ def load_models(model_filename):
default_dtype = torch.float16 default_dtype = torch.float16
else: else:
default_dtype = torch.float16 if args.fp16 else torch.bfloat16 default_dtype = torch.float16 if args.fp16 else torch.bfloat16
model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
updated_model_filename = []
for filename in model_filelist:
if default_dtype == torch.float16 : if default_dtype == torch.float16 :
if "quanto" in model_filename: if "quanto_int8" in filename:
model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8") filename = filename.replace("quanto_int8", "quanto_fp16_int8")
download_models(model_filename, text_encoder_filename) elif "quanto_mbf16_int8" in filename:
filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
updated_model_filename.append(filename)
download_models(filename, text_encoder_filename)
model_filelist = updated_model_filename
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
if test_class_i2v(model_filename): transformer_filename = None
res720P = "720p" in model_filename new_transformer_filename = model_filelist[-1]
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) if test_class_i2v(new_transformer_filename):
wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
else: else:
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer) wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model._model_file_name = model_filename wan_model._model_file_name = new_transformer_filename
kwargs = { "extraModelsToQuantize": None} kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4: if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 } kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
@ -1914,7 +1930,7 @@ def load_models(model_filename):
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs) offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
if len(args.gpu) > 0: if len(args.gpu) > 0:
torch.set_default_device(args.gpu) torch.set_default_device(args.gpu)
transformer_filename = new_transformer_filename
return wan_model, offloadobj, pipe["transformer"] return wan_model, offloadobj, pipe["transformer"]
if not "P" in preload_model_policy: if not "P" in preload_model_policy:
@ -2033,13 +2049,7 @@ def apply_changes( state,
preload_model_policy = server_config["preload_model_policy"] preload_model_policy = server_config["preload_model_policy"]
transformer_quantization = server_config["transformer_quantization"] transformer_quantization = server_config["transformer_quantization"]
transformer_types = server_config["transformer_types"] transformer_types = server_config["transformer_types"]
model_filename = state["model_filename"]
model_transformer_type = get_model_type(model_filename)
if not model_transformer_type in transformer_types:
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ): if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
model_choice = gr.Dropdown() model_choice = gr.Dropdown()
else: else:
@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
frame_height, frame_width, _ = frames_list[0].shape frame_height, frame_width, _ = frames_list[0].shape
if fit_canvas : if fit_canvas :
scale = min(height / frame_height, width / frame_width) scale1 = min(height / frame_height, width / frame_width)
scale2 = min(height / frame_width, width / frame_height)
scale = max(scale1, scale2)
else: else:
scale = ((height * width ) / (frame_height * frame_width))**(1/2) scale = ((height * width ) / (frame_height * frame_width))**(1/2)
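A worked sketch (not from this diff) of the new fit_canvas scaling above, which now also evaluates the rotated orientation and keeps whichever scale preserves more of the frame.

def compute_scale(height, width, frame_height, frame_width, fit_canvas=True):
    if fit_canvas:
        scale1 = min(height / frame_height, width / frame_width)   # same orientation
        scale2 = min(height / frame_width, width / frame_height)   # rotated orientation
        return max(scale1, scale2)
    # Otherwise only match the total pixel budget.
    return ((height * width) / (frame_height * frame_width)) ** 0.5

# A 1280x720 portrait guide fitted to a 480x832 landscape canvas:
print(compute_scale(480, 832, 1280, 720))  # 0.65, vs 0.375 with the old single-orientation rule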
@ -2356,6 +2368,7 @@ def generate_video(
seed, seed,
num_inference_steps, num_inference_steps,
guidance_scale, guidance_scale,
audio_guidance_scale,
flow_shift, flow_shift,
embedded_guidance_scale, embedded_guidance_scale,
repeat_generation, repeat_generation,
@ -2375,8 +2388,10 @@ def generate_video(
video_guide, video_guide,
keep_frames_video_guide, keep_frames_video_guide,
video_mask, video_mask,
audio_guide,
sliding_window_size, sliding_window_size,
sliding_window_overlap, sliding_window_overlap,
sliding_window_overlap_noise,
sliding_window_discard_last_frames, sliding_window_discard_last_frames,
remove_background_image_ref, remove_background_image_ref,
temporal_upsampling, temporal_upsampling,
@ -2508,6 +2523,15 @@ def generate_video(
# VAE Tiling # VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
diffusion_forcing = "diffusion_forcing" in model_filename
vace = "Vace" in model_filename
if diffusion_forcing:
fps = 24
elif audio_guide != None:
fps = 23
else:
fps = 16
joint_pass = boost ==1 #and profile != 1 and profile != 3 joint_pass = boost ==1 #and profile != 1 and profile != 3
# TeaCache # TeaCache
trans.enable_teacache = tea_cache_setting > 0 trans.enable_teacache = tea_cache_setting > 0
@ -2517,12 +2541,10 @@ def generate_video(
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100) trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
if image2video: if image2video:
if '480p' in model_filename: if '720p' in model_filename:
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
elif '720p' in model_filename:
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else: else:
raise gr.Error("Teacache not supported for this model") trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
else: else:
if '1.3B' in model_filename: if '1.3B' in model_filename:
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
@ -2535,6 +2557,18 @@ def generate_video(
if "recam" in model_filename: if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True) source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
target_camera = model_mode target_camera = model_mode
audio_proj_split = None
audio_scale = None
audio_context_lens = None
if audio_guide != None:
from fantasytalking.infer import parse_audio
import librosa
duration = librosa.get_duration(path=audio_guide)
video_length = min(int(fps * duration // 4) * 4 + 5, video_length)
audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device )
audio_scale = 1.0
import random import random
if seed == None or seed <0: if seed == None or seed <0:
seed = random.randint(0, 999999999) seed = random.randint(0, 999999999)
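A worked example (not from this diff) of the frame-count clamp above, which trims the requested length to the audio duration while keeping a 4·n + 1 frame count (with one extra group of four frames of headroom); fps 23 is the Fantasy Speaking path.

def clamp_length_to_audio(requested_frames, audio_seconds, fps=23):
    from_audio = int(fps * audio_seconds // 4) * 4 + 5
    return min(from_audio, requested_frames)

# A 3.2 s voice track at 23 fps gives 73.6 raw frames -> 77 (4*18 + 5), capped by the request.
print(clamp_length_to_audio(81, 3.2))   # 77
print(clamp_length_to_audio(49, 10.0))  # 49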
@ -2551,11 +2585,11 @@ def generate_video(
extra_generation = 0 extra_generation = 0
initial_total_windows = 0 initial_total_windows = 0
max_frames_to_generate = video_length max_frames_to_generate = video_length
diffusion_forcing = "diffusion_forcing" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename phantom = "phantom" in model_filename
if diffusion_forcing or vace: if diffusion_forcing or vace:
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
else:
reuse_frames = 0
if diffusion_forcing and source_video != None: if diffusion_forcing and source_video != None:
video_length += sliding_window_overlap video_length += sliding_window_overlap
sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
@ -2571,10 +2605,8 @@ def generate_video(
initial_total_windows = 1 initial_total_windows = 1
first_window_video_length = video_length first_window_video_length = video_length
fps = 24 if diffusion_forcing else 16
gen["sliding_window"] = sliding_window gen["sliding_window"] = sliding_window
while not abort: while not abort:
extra_generation += gen.get("extra_orders",0) extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0 gen["extra_orders"] = 0
@ -2594,6 +2626,7 @@ def generate_video(
guide_start_frame = 0 guide_start_frame = 0
video_length = first_window_video_length video_length = first_window_video_length
gen["extra_windows"] = 0 gen["extra_windows"] = 0
start_time = time.time()
while not abort: while not abort:
if sliding_window: if sliding_window:
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
@ -2642,7 +2675,7 @@ def generate_video(
if preprocess_type != None : if preprocess_type != None :
send_cmd("progress", progress_args) send_cmd("progress", progress_args)
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps) video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate) keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
if len(error) > 0: if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
@ -2678,8 +2711,7 @@ def generate_video(
trans.teacache_counter = 0 trans.teacache_counter = 0
trans.num_steps = num_inference_steps trans.num_steps = num_inference_steps
trans.teacache_skipped_steps = 0 trans.teacache_skipped_steps = 0
trans.previous_residual_uncond = None trans.previous_residual = None
trans.previous_residual_cond = None
if image2video: if image2video:
samples = wan_model.generate( samples = wan_model.generate(
@ -2687,7 +2719,9 @@ def generate_video(
image_start, image_start,
image_end if image_end != None else None, image_end if image_end != None else None,
frame_num=(video_length // 4)* 4 + 1, frame_num=(video_length // 4)* 4 + 1,
max_area=MAX_AREA_CONFIGS[resolution_reformated], # max_area=MAX_AREA_CONFIGS[resolution_reformated],
height = height,
width = width,
shift=flow_shift, shift=flow_shift,
sampling_steps=num_inference_steps, sampling_steps=num_inference_steps,
guide_scale=guidance_scale, guide_scale=guidance_scale,
@ -2702,7 +2736,11 @@ def generate_video(
slg_end = slg_end_perc/100, slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch, cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step, cfg_zero_step = cfg_zero_step,
add_frames_for_end_image = "image2video" in model_filename add_frames_for_end_image = "image2video" in model_filename,
audio_cfg_scale= audio_guidance_scale,
audio_proj= audio_proj_split,
audio_scale= audio_scale,
audio_context_lens= audio_context_lens
) )
elif diffusion_forcing: elif diffusion_forcing:
samples = wan_model.generate( samples = wan_model.generate(
@ -2720,7 +2758,10 @@ def generate_video(
callback= callback, callback= callback,
VAE_tile_size = VAE_tile_size, VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass, joint_pass = joint_pass,
addnoise_condition = 20, slg_layers = slg_layers,
slg_start = slg_start_perc/100,
slg_end = slg_end_perc/100,
addnoise_condition = sliding_window_overlap_noise,
ar_step = model_mode, #5 ar_step = model_mode, #5
causal_block_size = 5, causal_block_size = 5,
causal_attention = True, causal_attention = True,
@ -2751,6 +2792,9 @@ def generate_video(
slg_end = slg_end_perc/100, slg_end = slg_end_perc/100,
cfg_star_switch = cfg_star_switch, cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step, cfg_zero_step = cfg_zero_step,
overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1),
overlap_noise = sliding_window_overlap_noise,
vace = vace
) )
except Exception as e: except Exception as e:
if temp_filename!= None and os.path.isfile(temp_filename): if temp_filename!= None and os.path.isfile(temp_filename):
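A small worked example (not from this diff) of the overlapped_latents argument added above, converting the reused pixel frames of a sliding window into latent frames given the VAE's temporal stride of 4.

def overlapped_latent_count(reuse_frames, window_no):
    # The first window has nothing to overlap; later windows reuse (reuse_frames - 1)//4 + 1 latents.
    if reuse_frames == 0 or window_no == 1:
        return 0
    return (reuse_frames - 1) // 4 + 1

# 17 reused frames map to 5 latent frames, 9 to 3, 1 to 1.
print([overlapped_latent_count(n, window_no=2) for n in (17, 9, 1)])  # [5, 3, 1]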
@ -2782,11 +2826,11 @@ def generate_video(
print('\n'.join(tb)) print('\n'.join(tb))
send_cmd("error", new_error) send_cmd("error", new_error)
return return
finally:
trans.previous_residual = None
if trans.enable_teacache: if trans.enable_teacache:
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" ) print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
if samples != None: if samples != None:
samples = samples.to("cpu") samples = samples.to("cpu")
@ -2810,14 +2854,27 @@ def generate_video(
if discard_last_frames > 0: if discard_last_frames > 0:
sample = sample[: , :-discard_last_frames] sample = sample[: , :-discard_last_frames]
guide_start_frame -= discard_last_frames guide_start_frame -= discard_last_frames
if reuse_frames == 0:
pre_video_guide = sample[:,9999 :]
else:
# noise_factor = 200/ 1000
# pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor
pre_video_guide = sample[:, -reuse_frames:] pre_video_guide = sample[:, -reuse_frames:]
if prefix_video != None: if prefix_video != None:
if reuse_frames == 0:
sample = torch.cat([ prefix_video[:, :], sample], dim = 1)
else:
sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1) sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
prefix_video = None prefix_video = None
if sliding_window and window_no > 1: if sliding_window and window_no > 1:
if reuse_frames == 0:
sample = sample[: , :]
else:
sample = sample[: , reuse_frames:] sample = sample[: , reuse_frames:]
guide_start_frame -= reuse_frames
guide_start_frame -= reuse_frames
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
if os.name == 'nt': if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
@ -2875,18 +2932,23 @@ def generate_video(
sample = torch.cat([frames_already_processed, sample], dim=1) sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample frames_already_processed = sample
cache_video( if audio_guide == None:
tensor=sample[None], cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
save_file=video_path, else:
fps=fps, save_path_tmp = video_path[:-4] + "_tmp.mp4"
nrow=1, cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
normalize=True, final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ]
value_range=(-1, 1)) import subprocess
subprocess.run(final_command, check=True)
os.remove(save_path_tmp)
end_time = time.time()
inputs = get_function_arguments(generate_video, locals()) inputs = get_function_arguments(generate_video, locals())
inputs.pop("send_cmd") inputs.pop("send_cmd")
inputs.pop("task_id")
configs = prepare_inputs_dict("metadata", inputs) configs = prepare_inputs_dict("metadata", inputs)
configs["generation_time"] = round(end_time-start_time)
metadata_choice = server_config.get("metadata_type","metadata") metadata_choice = server_config.get("metadata_type","metadata")
if metadata_choice == "json": if metadata_choice == "json":
with open(video_path.replace('.mp4', '.json'), 'w') as f: with open(video_path.replace('.mp4', '.json'), 'w') as f:
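A standalone sketch (not from this diff) of the audio-muxing step above: the silent render is written to a temporary file and ffmpeg re-encodes it with the guide audio, -shortest ending the output at the shorter of the two tracks.

import os
import subprocess

def mux_audio(video_path, audio_path):
    # Move the silent render aside, then remux it with the audio track in place.
    tmp_path = video_path[:-4] + "_tmp.mp4"
    os.replace(video_path, tmp_path)
    subprocess.run(["ffmpeg", "-y", "-i", tmp_path, "-i", audio_path,
                    "-c:v", "libx264", "-c:a", "aac", "-shortest",
                    "-loglevel", "warning", "-nostats", video_path], check=True)
    os.remove(tmp_path)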
@ -3113,7 +3175,7 @@ def get_latest_status(state):
prompt_no = gen["prompt_no"] prompt_no = gen["prompt_no"]
prompts_max = gen.get("prompts_max",0) prompts_max = gen.get("prompts_max",0)
total_generation = gen.get("total_generation", 1) total_generation = gen.get("total_generation", 1)
repeat_no = gen["repeat_no"] repeat_no = gen.get("repeat_no",0)
total_generation += gen.get("extra_orders", 0) total_generation += gen.get("extra_orders", 0)
total_windows = gen.get("total_windows", 0) total_windows = gen.get("total_windows", 0)
total_windows += gen.get("extra_windows", 0) total_windows += gen.get("extra_windows", 0)
@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ):
if target == "state": if target == "state":
return inputs return inputs
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"] unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"]
for k in unsaved_params: for k in unsaved_params:
inputs.pop(k) inputs.pop(k)
@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ):
inputs.pop(k) inputs.pop(k)
if not "Vace" in model_filename or "diffusion_forcing" in model_filename: if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"] unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
for k in unsaved_params: for k in unsaved_params:
inputs.pop(k) inputs.pop(k)
if not "fantasy" in model_filename:
inputs.pop("audio_guidance_scale")
if target == "metadata": if target == "metadata":
inputs = {k: v for k,v in inputs.items() if v != None } inputs = {k: v for k,v in inputs.items() if v != None }
@@ -3511,6 +3577,7 @@ def save_inputs(
seed,
num_inference_steps,
guidance_scale,
+ audio_guidance_scale,
flow_shift,
embedded_guidance_scale,
repeat_generation,
@@ -3530,8 +3597,10 @@ def save_inputs(
video_guide,
keep_frames_video_guide,
video_mask,
+ audio_guide,
sliding_window_size,
sliding_window_overlap,
+ sliding_window_overlap_noise,
sliding_window_discard_last_frames,
remove_background_image_ref,
temporal_upsampling,
@@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
recammaster = "recam" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
+ fantasy = "fantasy" in model_filename
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
if diffusion_forcing:
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
@@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
+ audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy )
advanced_prompt = advanced_ui
prompt_vars=[]
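The new `audio_guide` widget only collects a file path; turning that voice track into per-frame conditioning happens in the FantasySpeaking pipeline and is not visible in this hunk. As a hedged sketch of the usual approach (a Wav2Vec2 encoder whose outputs are resampled to one feature row per video frame), assuming a standard `transformers` checkpoint; the function below is illustrative, not the project's implementation:

```python
import torch
import librosa
from transformers import Wav2Vec2Model, Wav2Vec2Processor

def extract_audio_features(audio_path, num_frames, device="cuda"):
    # Load the voice track at the 16 kHz rate Wav2Vec2 expects.
    wav, sr = librosa.load(audio_path, sr=16000)
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
    inputs = processor(wav, sampling_rate=sr, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state      # (1, T_audio, C)
    # Resample the audio timeline so each video frame gets one embedding.
    per_frame = torch.nn.functional.interpolate(
        hidden.transpose(1, 2), size=num_frames, mode="linear"
    )
    return per_frame.transpose(1, 2)                    # (1, num_frames, C)
```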
@@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
with gr.Row():
- if test_class_i2v(model_filename):
+ if test_class_i2v(model_filename) and False:
resolution = gr.Dropdown(
choices=[
# 720p
- ("720p", "1280x720"),
+ ("720p (same amount of pixels)", "1280x720"),
- ("480p", "832x480"),
+ ("480p (same amount of pixels)", "832x480"),
],
value=ui_defaults.get("resolution","480p"),
label="Resolution (video will have the same height / width ratio than the original image)"
@@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p)", "1024x024"),
- # ("832x1104 (3:4, 720p)", "832x1104"),
+ ("832x1104 (3:4, 720p)", "832x1104"),
- # ("960x960 (1:1, 720p)", "960x960"),
+ ("1104x832 (3:4, 720p)", "1104x832"),
+ ("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
- # ("832x624 (4:3, 540p)", "832x624"),
+ ("832x624 (4:3, 480p)", "832x624"),
- # ("624x832 (3:4, 540p)", "624x832"),
+ ("624x832 (3:4, 480p)", "624x832"),
- # ("720x720 (1:1, 540p)", "720x720"),
+ ("720x720 (1:1, 480p)", "720x720"),
+ ("512x512 (1:1, 480p)", "512x512"),
],
value=ui_defaults.get("resolution","832x480"),
- label="Resolution"
+ label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
)
with gr.Row():
if recammaster:
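The relabelled dropdown now advertises a maximum rather than an exact size for image-to-video models: the selection behaves like a pixel budget, and the final width/height follow the source image's aspect ratio. A hedged sketch of that kind of fitting logic, with an illustrative helper that is not the app's actual code:

```python
def fit_resolution(src_width, src_height, max_resolution="832x480", multiple=16):
    # Keep the source aspect ratio while staying within the pixel budget of
    # the selected "max" resolution, rounded down to a model-friendly multiple.
    max_w, max_h = (int(v) for v in max_resolution.split("x"))
    budget = max_w * max_h
    ratio = src_width / src_height
    height = int((budget / ratio) ** 0.5)
    width = int(height * ratio)
    width = max(multiple, width // multiple * multiple)
    height = max(multiple, height // multiple * multiple)
    return width, height

# Example: a 1080x1920 portrait photo under an 832x480 budget lands on 464x832.
print(fit_resolution(1080, 1920, "832x480"))
```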
@@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
elif vace:
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
+ elif fantasy:
+ video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
else:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
with gr.Row():
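Note the changed frame-rate hint: the FantasySpeaking slider assumes 23 frames per second instead of the 16 used by most other Wan modes, so the same frame count yields a shorter clip. A quick sanity check, assuming duration is simply frames divided by fps:

```python
def frames_to_seconds(num_frames, fps):
    return num_frames / fps

# Default 81 frames: ~3.5 s at 23 fps (FantasySpeaking) vs ~5.1 s at 16 fps.
print(round(frames_to_seconds(81, 23), 2))  # 3.52
print(round(frames_to_seconds(81, 16), 2))  # 5.06
```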
@@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
- ], visible= True, label= "Multiple Images as Texts Prompts"
+ ], visible= test_class_i2v(model_filename), label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
+ audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
with gr.Row():
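A separate `audio_guidance_scale` next to the text guidance suggests the FantasySpeaking path runs two guidance terms, one for the prompt and one for the voice track. How they are combined is not shown in this diff; a commonly used dual classifier-free-guidance formulation, given purely as a hedged sketch rather than the model's actual code, looks like this:

```python
import torch

def dual_cfg(noise_uncond, noise_text, noise_text_audio,
             guidance_scale=5.0, audio_guidance_scale=5.0):
    # Standard text CFG plus a second term that pulls the prediction toward
    # the audio-conditioned branch; the real model may weight these differently.
    return (noise_uncond
            + guidance_scale * (noise_text - noise_uncond)
            + audio_guidance_scale * (noise_text_audio - noise_text))

# Toy call on random latents just to show the three branches involved.
e_u, e_t, e_ta = (torch.randn(1, 16, 8, 8) for _ in range(3))
guided = dual_cfg(e_u, e_t, e_ta)
```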
@@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Tab("Quality"):
with gr.Row():
- gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
+ gr.Markdown("<B>Skip Layer Guidance (improves video quality)</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if diffusion_forcing:
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
- sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
+ sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
+ sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
else:
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
- sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+ sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
- sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
+ sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
+ sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
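The new `sliding_window_overlap_noise` slider exposes a 0–100 strength whose stated purpose is to re-noise the frames reused from the previous window so their slight blur does not propagate into the next one. The exact schedule is not visible here; a hedged sketch of the general idea, treating the slider as a percentage-like blend factor (illustrative only):

```python
import torch

def renoise_overlap(overlap_latents, overlap_noise=20):
    # Blend fresh Gaussian noise into the frames carried over from the previous
    # sliding window so the sampler can re-sharpen them instead of locking onto
    # slightly blurry pixels: 0 keeps them untouched, 100 replaces them almost
    # entirely with noise.
    strength = overlap_noise / 100.0
    noise = torch.randn_like(overlap_latents)
    return (1.0 - strength) * overlap_latents + strength * noise
```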
@@ -5035,7 +5112,7 @@ def create_demo():
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
- gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
+ gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.5 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })
@@ -5076,7 +5153,7 @@ def create_demo():
if __name__ == "__main__":
atexit.register(autosave_queue)
- # download_ffmpeg()
+ download_ffmpeg()
# threading.Thread(target=runner, daemon=True).start()
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
server_port = int(args.server_port)