mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 22:26:36 +00:00
Added support for fantasyspeaking model
This commit is contained in:
parent
4ecc866c7b
commit
bc9121ffc6
@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
|
|
||||||
## 🔥 Latest News!!
|
## 🔥 Latest News!!
|
||||||
|
* May 5 2025: 👋 Wan 2.1GP v4.5: FantasySpeaking model, you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos (see recommended settings). New high quality processing features (mixed 16/32 bits calculation and 32 bitsVAE)
|
||||||
* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, very good model to transfer people or objects into video, works quite well at 720p and with the number of steps > 30
|
* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, very good model to transfer people or objects into video, works quite well at 720p and with the number of steps > 30
|
||||||
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below).Note that Skyreel uses causal attention that is only supported by Sdpa attention so even if chose an other type of attention, some of the processes will use Sdpa attention.
|
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below).Note that Skyreel uses causal attention that is only supported by Sdpa attention so even if chose an other type of attention, some of the processes will use Sdpa attention.
|
||||||
|
|
||||||
@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
|
|||||||
|
|
||||||
There is also a guide that describes the various combination of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).Good luck !
|
There is also a guide that describes the various combination of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).Good luck !
|
||||||
|
|
||||||
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
|
It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration.
|
||||||
|
|
||||||
|
Other recommended setttings for Vace:
|
||||||
|
- Use a long prompt description especially for the people / objects that are in the background and not in reference images. This will ensure consistency between the windows.
|
||||||
|
- Set a medium size overlap window: long enough to give the model a sense of the motion but short enough so any overlapped blurred frames do no turn the rest of the video into a blurred video
|
||||||
|
- Truncate at least the last 4 frames of the each generated window as Vace last frames tends to be blurry
|
||||||
|
|
||||||
|
|
||||||
### VACE and Sky Reels v2 Diffusion Forcing Slidig Window
|
### VACE and Sky Reels v2 Diffusion Forcing Slidig Window
|
||||||
With this mode (that works for the moment only with Vace and Sky Reels v2) you can merge mutiple Videos to form a very long video (up to 1 min).
|
With this mode (that works for the moment only with Vace and Sky Reels v2) you can merge mutiple Videos to form a very long video (up to 1 min).
|
||||||
|
|||||||
27
fantasytalking/infer.py
Normal file
27
fantasytalking/infer.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright Alibaba Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
from transformers import Wav2Vec2Model, Wav2Vec2Processor
|
||||||
|
|
||||||
|
from .model import FantasyTalkingAudioConditionModel
|
||||||
|
from .utils import get_audio_features
|
||||||
|
|
||||||
|
|
||||||
|
def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
|
||||||
|
fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
|
||||||
|
from mmgp import offload
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from fantasytalking.model import AudioProjModel
|
||||||
|
with init_empty_weights():
|
||||||
|
proj_model = AudioProjModel( 768, 2048)
|
||||||
|
offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
|
||||||
|
proj_model.to(device).eval().requires_grad_(False)
|
||||||
|
|
||||||
|
wav2vec_model_dir = "ckpts/wav2vec"
|
||||||
|
wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
|
||||||
|
wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
|
||||||
|
audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
|
||||||
|
|
||||||
|
audio_proj_fea = proj_model(audio_wav2vec_fea)
|
||||||
|
pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
|
||||||
|
audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
|
||||||
|
return audio_proj_split, audio_context_lens
|
||||||
162
fantasytalking/model.py
Normal file
162
fantasytalking/model.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from wan.modules.attention import pay_attention
|
||||||
|
|
||||||
|
|
||||||
|
class AudioProjModel(nn.Module):
|
||||||
|
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
|
||||||
|
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||||
|
|
||||||
|
def forward(self, audio_embeds):
|
||||||
|
context_tokens = self.proj(audio_embeds)
|
||||||
|
context_tokens = self.norm(context_tokens)
|
||||||
|
return context_tokens # [B,L,C]
|
||||||
|
|
||||||
|
class WanCrossAttentionProcessor(nn.Module):
|
||||||
|
def __init__(self, context_dim, hidden_dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
|
||||||
|
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.k_proj.weight)
|
||||||
|
nn.init.zeros_(self.v_proj.weight)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
audio_proj: torch.Tensor,
|
||||||
|
latents_num_frames: int = 21,
|
||||||
|
audio_context_lens = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
audio_proj: [B, 21, L3, C]
|
||||||
|
audio_context_lens: [B*21].
|
||||||
|
"""
|
||||||
|
b, l, n, d = q.shape
|
||||||
|
|
||||||
|
if len(audio_proj.shape) == 4:
|
||||||
|
audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
|
||||||
|
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
||||||
|
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
||||||
|
qkv_list = [audio_q, ip_key, ip_value]
|
||||||
|
del q, audio_q, ip_key, ip_value
|
||||||
|
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
|
||||||
|
audio_x = audio_x.view(b, l, n, d)
|
||||||
|
audio_x = audio_x.flatten(2)
|
||||||
|
elif len(audio_proj.shape) == 3:
|
||||||
|
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
|
||||||
|
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
|
||||||
|
qkv_list = [q, ip_key, ip_value]
|
||||||
|
del q, ip_key, ip_value
|
||||||
|
audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
|
||||||
|
audio_x = audio_x.flatten(2)
|
||||||
|
return audio_x
|
||||||
|
|
||||||
|
|
||||||
|
class FantasyTalkingAudioConditionModel(nn.Module):
|
||||||
|
def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.audio_in_dim = audio_in_dim
|
||||||
|
self.audio_proj_dim = audio_proj_dim
|
||||||
|
|
||||||
|
def split_audio_sequence(self, audio_proj_length, num_frames=81):
|
||||||
|
"""
|
||||||
|
Map the audio feature sequence to corresponding latent frame slices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_proj_length (int): The total length of the audio feature sequence
|
||||||
|
(e.g., 173 in audio_proj[1, 173, 768]).
|
||||||
|
num_frames (int): The number of video frames in the training data (default: 81).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
|
||||||
|
(within the audio feature sequence) corresponding to a latent frame.
|
||||||
|
"""
|
||||||
|
# Average number of tokens per original video frame
|
||||||
|
tokens_per_frame = audio_proj_length / num_frames
|
||||||
|
|
||||||
|
# Each latent frame covers 4 video frames, and we want the center
|
||||||
|
tokens_per_latent_frame = tokens_per_frame * 4
|
||||||
|
half_tokens = int(tokens_per_latent_frame / 2)
|
||||||
|
|
||||||
|
pos_indices = []
|
||||||
|
for i in range(int((num_frames - 1) / 4) + 1):
|
||||||
|
if i == 0:
|
||||||
|
pos_indices.append(0)
|
||||||
|
else:
|
||||||
|
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
|
||||||
|
end_token = tokens_per_frame * (i * 4 + 1)
|
||||||
|
center_token = int((start_token + end_token) / 2) - 1
|
||||||
|
pos_indices.append(center_token)
|
||||||
|
|
||||||
|
# Build index ranges centered around each position
|
||||||
|
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
|
||||||
|
|
||||||
|
# Adjust the first range to avoid negative start index
|
||||||
|
pos_idx_ranges[0] = [
|
||||||
|
-(half_tokens * 2 - pos_idx_ranges[1][0]),
|
||||||
|
pos_idx_ranges[1][0],
|
||||||
|
]
|
||||||
|
|
||||||
|
return pos_idx_ranges
|
||||||
|
|
||||||
|
def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
|
||||||
|
"""
|
||||||
|
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
|
||||||
|
if the range exceeds the input boundaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
|
||||||
|
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
|
||||||
|
expand_length (int): Number of tokens to expand on both sides of each subsequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
|
||||||
|
Each element is a padded subsequence.
|
||||||
|
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
|
||||||
|
Useful for ignoring padding tokens in attention masks.
|
||||||
|
"""
|
||||||
|
pos_idx_ranges = [
|
||||||
|
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
|
||||||
|
]
|
||||||
|
sub_sequences = []
|
||||||
|
seq_len = input_tensor.size(1) # 173
|
||||||
|
max_valid_idx = seq_len - 1 # 172
|
||||||
|
k_lens_list = []
|
||||||
|
for start, end in pos_idx_ranges:
|
||||||
|
# Calculate the fill amount
|
||||||
|
pad_front = max(-start, 0)
|
||||||
|
pad_back = max(end - max_valid_idx, 0)
|
||||||
|
|
||||||
|
# Calculate the start and end indices of the valid part
|
||||||
|
valid_start = max(start, 0)
|
||||||
|
valid_end = min(end, max_valid_idx)
|
||||||
|
|
||||||
|
# Extract the valid part
|
||||||
|
if valid_start <= valid_end:
|
||||||
|
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
|
||||||
|
else:
|
||||||
|
valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
|
||||||
|
|
||||||
|
# In the sequence dimension (the 1st dimension) perform padding
|
||||||
|
padded_subseq = F.pad(
|
||||||
|
valid_part,
|
||||||
|
(0, 0, 0, pad_back + pad_front, 0, 0),
|
||||||
|
mode="constant",
|
||||||
|
value=0,
|
||||||
|
)
|
||||||
|
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
|
||||||
|
|
||||||
|
sub_sequences.append(padded_subseq)
|
||||||
|
return torch.stack(sub_sequences, dim=1), torch.tensor(
|
||||||
|
k_lens_list, dtype=torch.long
|
||||||
|
)
|
||||||
52
fantasytalking/utils.py
Normal file
52
fantasytalking/utils.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright Alibaba Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image_by_longest_edge(image_path, target_size):
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
width, height = image.size
|
||||||
|
scale = target_size / max(width, height)
|
||||||
|
new_size = (int(width * scale), int(height * scale))
|
||||||
|
return image.resize(new_size, Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||||
|
writer = imageio.get_writer(
|
||||||
|
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
|
||||||
|
)
|
||||||
|
for frame in tqdm(frames, desc="Saving video"):
|
||||||
|
frame = np.array(frame)
|
||||||
|
writer.append_data(frame)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
|
||||||
|
sr = 16000
|
||||||
|
audio_input, sample_rate = librosa.load(audio_path, sr=sr) # 采样率为 16kHz
|
||||||
|
|
||||||
|
start_time = 0
|
||||||
|
# end_time = (0 + (num_frames - 1) * 1) / fps
|
||||||
|
end_time = num_frames / fps
|
||||||
|
|
||||||
|
start_sample = int(start_time * sr)
|
||||||
|
end_sample = int(end_time * sr)
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio_segment = audio_input[start_sample:end_sample]
|
||||||
|
except:
|
||||||
|
audio_segment = audio_input
|
||||||
|
|
||||||
|
input_values = audio_processor(
|
||||||
|
audio_segment, sampling_rate=sample_rate, return_tensors="pt"
|
||||||
|
).input_values.to("cuda")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
fea = wav2vec(input_values).last_hidden_state
|
||||||
|
|
||||||
|
return fea
|
||||||
@ -16,7 +16,7 @@ gradio==5.23.0
|
|||||||
numpy>=1.23.5,<2
|
numpy>=1.23.5,<2
|
||||||
einops
|
einops
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
mmgp==3.4.1
|
mmgp==3.4.2
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
mutagen
|
mutagen
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
@ -28,4 +28,5 @@ timm
|
|||||||
segment-anything
|
segment-anything
|
||||||
omegaconf
|
omegaconf
|
||||||
hydra-core
|
hydra-core
|
||||||
|
librosa
|
||||||
# rembg==2.0.65
|
# rembg==2.0.65
|
||||||
|
|||||||
@ -44,6 +44,8 @@ SUPPORTED_SIZES = {
|
|||||||
VACE_SIZE_CONFIGS = {
|
VACE_SIZE_CONFIGS = {
|
||||||
'480*832': (480, 832),
|
'480*832': (480, 832),
|
||||||
'832*480': (832, 480),
|
'832*480': (832, 480),
|
||||||
|
'720*1280': (720, 1280),
|
||||||
|
'1280*720': (1280, 720),
|
||||||
}
|
}
|
||||||
|
|
||||||
VACE_MAX_AREA_CONFIGS = {
|
VACE_MAX_AREA_CONFIGS = {
|
||||||
|
|||||||
@ -56,16 +56,18 @@ class DTT2V:
|
|||||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {model_filename}")
|
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
# model_filename = "model.safetensors"
|
# model_filename = "model.safetensors"
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
|
# model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
|
||||||
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
|
||||||
# offload.load_model_data(self.model, "recam.ckpt")
|
# offload.load_model_data(self.model, "recam.ckpt")
|
||||||
# self.model.cpu()
|
# self.model.cpu()
|
||||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
|
# dtype = torch.float16
|
||||||
|
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
||||||
offload.change_dtype(self.model, dtype, True)
|
offload.change_dtype(self.model, dtype, True)
|
||||||
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
|
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
|
||||||
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
|
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json")
|
||||||
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
|
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
|
||||||
|
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
@ -200,6 +202,9 @@ class DTT2V:
|
|||||||
fps: int = 24,
|
fps: int = 24,
|
||||||
VAE_tile_size = 0,
|
VAE_tile_size = 0,
|
||||||
joint_pass = False,
|
joint_pass = False,
|
||||||
|
slg_layers = None,
|
||||||
|
slg_start = 0.0,
|
||||||
|
slg_end = 1.0,
|
||||||
callback = None,
|
callback = None,
|
||||||
):
|
):
|
||||||
self._interrupt = False
|
self._interrupt = False
|
||||||
@ -211,6 +216,7 @@ class DTT2V:
|
|||||||
|
|
||||||
if ar_step == 0:
|
if ar_step == 0:
|
||||||
causal_block_size = 1
|
causal_block_size = 1
|
||||||
|
causal_attention = False
|
||||||
|
|
||||||
i2v_extra_kwrags = {}
|
i2v_extra_kwrags = {}
|
||||||
prefix_video = None
|
prefix_video = None
|
||||||
@ -252,31 +258,33 @@ class DTT2V:
|
|||||||
prefix_video = output_video.to(self.device)
|
prefix_video = output_video.to(self.device)
|
||||||
else:
|
else:
|
||||||
causal_block_size = 1
|
causal_block_size = 1
|
||||||
|
causal_attention = False
|
||||||
ar_step = 0
|
ar_step = 0
|
||||||
prefix_video = image
|
prefix_video = image
|
||||||
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
|
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
|
||||||
if prefix_video.dtype == torch.uint8:
|
if prefix_video.dtype == torch.uint8:
|
||||||
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
|
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
|
||||||
prefix_video = prefix_video.to(self.device)
|
prefix_video = prefix_video.to(self.device)
|
||||||
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)]
|
||||||
predix_video_latent_length = prefix_video[0].shape[1]
|
predix_video_latent_length = prefix_video.shape[1]
|
||||||
truncate_len = predix_video_latent_length % causal_block_size
|
truncate_len = predix_video_latent_length % causal_block_size
|
||||||
if truncate_len != 0:
|
if truncate_len != 0:
|
||||||
if truncate_len == predix_video_latent_length:
|
if truncate_len == predix_video_latent_length:
|
||||||
causal_block_size = 1
|
causal_block_size = 1
|
||||||
|
causal_attention = False
|
||||||
|
ar_step = 0
|
||||||
else:
|
else:
|
||||||
print("the length of prefix video is truncated for the casual block size alignment.")
|
print("the length of prefix video is truncated for the casual block size alignment.")
|
||||||
predix_video_latent_length -= truncate_len
|
predix_video_latent_length -= truncate_len
|
||||||
prefix_video[0] = prefix_video[0][:, : predix_video_latent_length]
|
prefix_video = prefix_video[:, : predix_video_latent_length]
|
||||||
|
|
||||||
base_num_frames_iter = latent_length
|
base_num_frames_iter = latent_length
|
||||||
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
||||||
latents = self.prepare_latents(
|
latents = self.prepare_latents(
|
||||||
latent_shape, dtype=torch.float32, device=self.device, generator=generator
|
latent_shape, dtype=torch.float32, device=self.device, generator=generator
|
||||||
)
|
)
|
||||||
latents = [latents]
|
|
||||||
if prefix_video is not None:
|
if prefix_video is not None:
|
||||||
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
|
latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
|
||||||
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
||||||
base_num_frames_iter,
|
base_num_frames_iter,
|
||||||
init_timesteps,
|
init_timesteps,
|
||||||
@ -298,6 +306,8 @@ class DTT2V:
|
|||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
||||||
if self.model.enable_teacache:
|
if self.model.enable_teacache:
|
||||||
|
x_count = 2 if self.do_classifier_free_guidance else 1
|
||||||
|
self.model.previous_residual = [None] * x_count
|
||||||
time_steps_comb = []
|
time_steps_comb = []
|
||||||
self.model.num_steps = updated_num_steps
|
self.model.num_steps = updated_num_steps
|
||||||
for i, timestep_i in enumerate(step_matrix):
|
for i, timestep_i in enumerate(step_matrix):
|
||||||
@ -309,7 +319,7 @@ class DTT2V:
|
|||||||
self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
|
self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
|
||||||
del time_steps_comb
|
del time_steps_comb
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
|
freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
|
||||||
kwrags = {
|
kwrags = {
|
||||||
"freqs" :freqs,
|
"freqs" :freqs,
|
||||||
"fps" : fps_embeds,
|
"fps" : fps_embeds,
|
||||||
@ -320,27 +330,27 @@ class DTT2V:
|
|||||||
}
|
}
|
||||||
kwrags.update(i2v_extra_kwrags)
|
kwrags.update(i2v_extra_kwrags)
|
||||||
|
|
||||||
|
|
||||||
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
||||||
|
kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
|
||||||
|
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
offload.set_step_no_for_lora(self.model, i)
|
||||||
update_mask_i = step_update_mask[i]
|
update_mask_i = step_update_mask[i]
|
||||||
valid_interval_start, valid_interval_end = valid_interval[i]
|
valid_interval_start, valid_interval_end = valid_interval[i]
|
||||||
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
||||||
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
|
||||||
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
||||||
noise_factor = 0.001 * addnoise_condition
|
noise_factor = 0.001 * addnoise_condition
|
||||||
timestep_for_noised_condition = addnoise_condition
|
timestep_for_noised_condition = addnoise_condition
|
||||||
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
|
||||||
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
latent_model_input[:, valid_interval_start:predix_video_latent_length]
|
||||||
* (1.0 - noise_factor)
|
* (1.0 - noise_factor)
|
||||||
+ torch.randn_like(
|
+ torch.randn_like(
|
||||||
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
latent_model_input[:, valid_interval_start:predix_video_latent_length]
|
||||||
)
|
)
|
||||||
* noise_factor
|
* noise_factor
|
||||||
)
|
)
|
||||||
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
||||||
kwrags.update({
|
kwrags.update({
|
||||||
"x" : torch.stack([latent_model_input[0]]),
|
|
||||||
"t" : timestep,
|
"t" : timestep,
|
||||||
"current_step" : i,
|
"current_step" : i,
|
||||||
})
|
})
|
||||||
@ -349,6 +359,7 @@ class DTT2V:
|
|||||||
if True:
|
if True:
|
||||||
if not self.do_classifier_free_guidance:
|
if not self.do_classifier_free_guidance:
|
||||||
noise_pred = self.model(
|
noise_pred = self.model(
|
||||||
|
x=[latent_model_input],
|
||||||
context=[prompt_embeds],
|
context=[prompt_embeds],
|
||||||
**kwrags,
|
**kwrags,
|
||||||
)[0]
|
)[0]
|
||||||
@ -358,6 +369,7 @@ class DTT2V:
|
|||||||
else:
|
else:
|
||||||
if joint_pass:
|
if joint_pass:
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
noise_pred_cond, noise_pred_uncond = self.model(
|
||||||
|
x=[latent_model_input, latent_model_input],
|
||||||
context= [prompt_embeds, negative_prompt_embeds],
|
context= [prompt_embeds, negative_prompt_embeds],
|
||||||
**kwrags,
|
**kwrags,
|
||||||
)
|
)
|
||||||
@ -365,12 +377,16 @@ class DTT2V:
|
|||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model(
|
||||||
|
x=[latent_model_input],
|
||||||
|
x_id=0,
|
||||||
context=[prompt_embeds],
|
context=[prompt_embeds],
|
||||||
**kwrags,
|
**kwrags,
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
noise_pred_uncond = self.model(
|
noise_pred_uncond = self.model(
|
||||||
|
x=[latent_model_input],
|
||||||
|
x_id=1,
|
||||||
context=[negative_prompt_embeds],
|
context=[negative_prompt_embeds],
|
||||||
**kwrags,
|
**kwrags,
|
||||||
)[0]
|
)[0]
|
||||||
@ -380,18 +396,18 @@ class DTT2V:
|
|||||||
del noise_pred_cond, noise_pred_uncond
|
del noise_pred_cond, noise_pred_uncond
|
||||||
for idx in range(valid_interval_start, valid_interval_end):
|
for idx in range(valid_interval_start, valid_interval_end):
|
||||||
if update_mask_i[idx].item():
|
if update_mask_i[idx].item():
|
||||||
latents[0][:, idx] = sample_schedulers[idx].step(
|
latents[:, idx] = sample_schedulers[idx].step(
|
||||||
noise_pred[:, idx - valid_interval_start],
|
noise_pred[:, idx - valid_interval_start],
|
||||||
timestep_i[idx],
|
timestep_i[idx],
|
||||||
latents[0][:, idx],
|
latents[:, idx],
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
)[0]
|
)[0]
|
||||||
sample_schedulers_counter[idx] += 1
|
sample_schedulers_counter[idx] += 1
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, latents[0].squeeze(0), False)
|
callback(i, latents.squeeze(0), False)
|
||||||
|
|
||||||
x0 = latents[0].unsqueeze(0)
|
x0 = latents.unsqueeze(0)
|
||||||
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
|
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
|
||||||
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
||||||
return output_video
|
return output_video
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import sys
|
|||||||
import types
|
import types
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
@ -84,13 +84,29 @@ class WanI2V:
|
|||||||
config.clip_checkpoint),
|
config.clip_checkpoint),
|
||||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {model_filename}")
|
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
|
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
|
# fantasy = torch.load("c:/temp/fantasy.ckpt")
|
||||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
|
# proj_model = fantasy["proj_model"]
|
||||||
offload.change_dtype(self.model, dtype, True)
|
# audio_processor = fantasy["audio_processor"]
|
||||||
# offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
|
# offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
|
||||||
|
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
|
||||||
|
# for k,v in audio_processor.items():
|
||||||
|
# audio_processor[k] = v.to(torch.bfloat16)
|
||||||
|
# with open("fantasy_config.json", "r", encoding="utf-8") as reader:
|
||||||
|
# config_text = reader.read()
|
||||||
|
# config_json = json.loads(config_text)
|
||||||
|
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
|
||||||
|
# model_filename = [model_filename, "audio_processor_bf16.safetensors"]
|
||||||
|
# model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
|
||||||
|
# dtype = torch.float16
|
||||||
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
|
||||||
|
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
||||||
|
# offload.change_dtype(self.model, dtype, True)
|
||||||
|
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
|
||||||
|
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
|
||||||
|
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
|
||||||
|
|
||||||
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
|
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
|
||||||
self.model.eval().requires_grad_(False)
|
self.model.eval().requires_grad_(False)
|
||||||
@ -102,6 +118,8 @@ class WanI2V:
|
|||||||
input_prompt,
|
input_prompt,
|
||||||
img,
|
img,
|
||||||
img2 = None,
|
img2 = None,
|
||||||
|
height =720,
|
||||||
|
width = 1280,
|
||||||
max_area=720 * 1280,
|
max_area=720 * 1280,
|
||||||
frame_num=81,
|
frame_num=81,
|
||||||
shift=5.0,
|
shift=5.0,
|
||||||
@ -119,7 +137,11 @@ class WanI2V:
|
|||||||
slg_end = 1.0,
|
slg_end = 1.0,
|
||||||
cfg_star_switch = True,
|
cfg_star_switch = True,
|
||||||
cfg_zero_step = 5,
|
cfg_zero_step = 5,
|
||||||
add_frames_for_end_image = True
|
add_frames_for_end_image = True,
|
||||||
|
audio_scale=None,
|
||||||
|
audio_cfg_scale=None,
|
||||||
|
audio_proj=None,
|
||||||
|
audio_context_lens=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Generates video frames from input image and text prompt using diffusion process.
|
Generates video frames from input image and text prompt using diffusion process.
|
||||||
@ -167,13 +189,21 @@ class WanI2V:
|
|||||||
frame_num +=1
|
frame_num +=1
|
||||||
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
||||||
|
|
||||||
|
|
||||||
h, w = img.shape[1:]
|
h, w = img.shape[1:]
|
||||||
aspect_ratio = h / w
|
# aspect_ratio = h / w
|
||||||
|
|
||||||
|
scale1 = min(height / h, width / w)
|
||||||
|
scale2 = min(height / h, width / w)
|
||||||
|
scale = max(scale1, scale2)
|
||||||
|
new_height = int(h * scale)
|
||||||
|
new_width = int(w * scale)
|
||||||
|
|
||||||
lat_h = round(
|
lat_h = round(
|
||||||
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
|
new_height // self.vae_stride[1] //
|
||||||
self.patch_size[1] * self.patch_size[1])
|
self.patch_size[1] * self.patch_size[1])
|
||||||
lat_w = round(
|
lat_w = round(
|
||||||
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
new_width // self.vae_stride[2] //
|
||||||
self.patch_size[2] * self.patch_size[2])
|
self.patch_size[2] * self.patch_size[2])
|
||||||
h = lat_h * self.vae_stride[1]
|
h = lat_h * self.vae_stride[1]
|
||||||
w = lat_w * self.vae_stride[2]
|
w = lat_w * self.vae_stride[2]
|
||||||
@ -271,98 +301,101 @@ class WanI2V:
|
|||||||
|
|
||||||
# sample videos
|
# sample videos
|
||||||
latent = noise
|
latent = noise
|
||||||
batch_size = latent.shape[0]
|
batch_size = 1
|
||||||
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
||||||
|
|
||||||
arg_c = {
|
kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
|
||||||
'context': [context],
|
|
||||||
'clip_fea': clip_context,
|
|
||||||
'y': [y],
|
|
||||||
'freqs' : freqs,
|
|
||||||
'pipeline' : self,
|
|
||||||
'callback' : callback
|
|
||||||
}
|
|
||||||
|
|
||||||
arg_null = {
|
if audio_proj != None:
|
||||||
'context': [context_null],
|
kwargs.update({
|
||||||
'clip_fea': clip_context,
|
"audio_proj": audio_proj.to(self.dtype),
|
||||||
'y': [y],
|
"audio_context_lens": audio_context_lens,
|
||||||
'freqs' : freqs,
|
})
|
||||||
'pipeline' : self,
|
|
||||||
'callback' : callback
|
|
||||||
}
|
|
||||||
|
|
||||||
arg_both= {
|
|
||||||
'context': [context, context_null],
|
|
||||||
'clip_fea': clip_context,
|
|
||||||
'y': [y],
|
|
||||||
'freqs' : freqs,
|
|
||||||
'pipeline' : self,
|
|
||||||
'callback' : callback
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.model.enable_teacache:
|
if self.model.enable_teacache:
|
||||||
|
self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
|
||||||
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
||||||
|
|
||||||
# self.model.to(self.device)
|
# self.model.to(self.device)
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, None, True)
|
callback(-1, None, True)
|
||||||
|
latent = latent.to(self.device)
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
offload.set_step_no_for_lora(self.model, i)
|
||||||
slg_layers_local = None
|
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
||||||
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
|
latent_model_input = latent
|
||||||
slg_layers_local = slg_layers
|
|
||||||
|
|
||||||
latent_model_input = [latent.to(self.device)]
|
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
|
|
||||||
timestep = torch.stack(timestep).to(self.device)
|
timestep = torch.stack(timestep).to(self.device)
|
||||||
|
kwargs.update({
|
||||||
|
't' :timestep,
|
||||||
|
'current_step' :i,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
if joint_pass:
|
if joint_pass:
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
if audio_proj == None:
|
||||||
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
|
noise_pred_cond, noise_pred_uncond = self.model(
|
||||||
|
[latent_model_input, latent_model_input],
|
||||||
|
context=[context, context_null],
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
|
||||||
|
[latent_model_input, latent_model_input, latent_model_input],
|
||||||
|
context=[context, context, context_null],
|
||||||
|
audio_scale = [audio_scale, None, None ],
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model(
|
||||||
latent_model_input,
|
[latent_model_input],
|
||||||
t=timestep,
|
context=[context],
|
||||||
current_step=i,
|
audio_scale = None if audio_scale == None else [audio_scale],
|
||||||
is_uncond=False,
|
x_id=0,
|
||||||
**arg_c,
|
**kwargs,
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if audio_proj != None:
|
||||||
|
noise_pred_noaudio = self.model(
|
||||||
|
[latent_model_input],
|
||||||
|
x_id=1,
|
||||||
|
context=[context],
|
||||||
|
**kwargs,
|
||||||
|
)[0]
|
||||||
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
|
||||||
noise_pred_uncond = self.model(
|
noise_pred_uncond = self.model(
|
||||||
latent_model_input,
|
[latent_model_input],
|
||||||
t=timestep,
|
x_id=1 if audio_scale == None else 2,
|
||||||
current_step=i,
|
context=[context_null],
|
||||||
is_uncond=True,
|
**kwargs,
|
||||||
slg_layers=slg_layers_local,
|
|
||||||
**arg_null,
|
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
del latent_model_input
|
del latent_model_input
|
||||||
|
|
||||||
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
||||||
noise_pred_text = noise_pred_cond
|
|
||||||
if cfg_star_switch:
|
if cfg_star_switch:
|
||||||
positive_flat = noise_pred_text.view(batch_size, -1)
|
positive_flat = noise_pred_cond.view(batch_size, -1)
|
||||||
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
||||||
|
|
||||||
alpha = optimized_scale(positive_flat,negative_flat)
|
alpha = optimized_scale(positive_flat,negative_flat)
|
||||||
alpha = alpha.view(batch_size, 1, 1, 1)
|
alpha = alpha.view(batch_size, 1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
if (i <= cfg_zero_step):
|
if (i <= cfg_zero_step):
|
||||||
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
|
noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
|
||||||
else:
|
else:
|
||||||
noise_pred_uncond *= alpha
|
noise_pred_uncond *= alpha
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
|
if audio_scale == None:
|
||||||
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
||||||
del noise_pred_uncond
|
else:
|
||||||
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
||||||
|
noise_pred_uncond, noise_pred_noaudio = None, None
|
||||||
temp_x0 = sample_scheduler.step(
|
temp_x0 = sample_scheduler.step(
|
||||||
noise_pred.unsqueeze(0),
|
noise_pred.unsqueeze(0),
|
||||||
t,
|
t,
|
||||||
@ -376,9 +409,6 @@ class WanI2V:
|
|||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, latent, False)
|
callback(i, latent, False)
|
||||||
|
|
||||||
|
|
||||||
# x0 = [latent.to(self.device, dtype=self.dtype)]
|
|
||||||
|
|
||||||
x0 = [latent]
|
x0 = [latent]
|
||||||
|
|
||||||
# x0 = [lat_y]
|
# x0 = [lat_y]
|
||||||
|
|||||||
@ -57,15 +57,15 @@ def sageattn_wrapper(
|
|||||||
):
|
):
|
||||||
q,k, v = qkv_list
|
q,k, v = qkv_list
|
||||||
padding_length = q.shape[0] -attention_length
|
padding_length = q.shape[0] -attention_length
|
||||||
q = q[:attention_length, :, : ].unsqueeze(0)
|
q = q[:attention_length, :, : ]
|
||||||
k = k[:attention_length, :, : ].unsqueeze(0)
|
k = k[:attention_length, :, : ]
|
||||||
v = v[:attention_length, :, : ].unsqueeze(0)
|
v = v[:attention_length, :, : ]
|
||||||
if True:
|
if True:
|
||||||
qkv_list = [q,k,v]
|
qkv_list = [q,k,v]
|
||||||
del q, k ,v
|
del q, k ,v
|
||||||
o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
|
o = alt_sageattn(qkv_list, tensor_layout="NHD")
|
||||||
else:
|
else:
|
||||||
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
|
o = sageattn(q, k, v, tensor_layout="NHD")
|
||||||
del q, k ,v
|
del q, k ,v
|
||||||
|
|
||||||
qkv_list.clear()
|
qkv_list.clear()
|
||||||
@ -107,14 +107,14 @@ def sdpa_wrapper(
|
|||||||
attention_length
|
attention_length
|
||||||
):
|
):
|
||||||
q,k, v = qkv_list
|
q,k, v = qkv_list
|
||||||
padding_length = q.shape[0] -attention_length
|
padding_length = q.shape[1] -attention_length
|
||||||
q = q[:attention_length, :].transpose(0,1).unsqueeze(0)
|
q = q[:attention_length, :].transpose(1,2)
|
||||||
k = k[:attention_length, :].transpose(0,1).unsqueeze(0)
|
k = k[:attention_length, :].transpose(1,2)
|
||||||
v = v[:attention_length, :].transpose(0,1).unsqueeze(0)
|
v = v[:attention_length, :].transpose(1,2)
|
||||||
|
|
||||||
o = F.scaled_dot_product_attention(
|
o = F.scaled_dot_product_attention(
|
||||||
q, k, v, attn_mask=None, is_causal=False
|
q, k, v, attn_mask=None, is_causal=False
|
||||||
).squeeze(0).transpose(0,1)
|
).transpose(1,2)
|
||||||
del q, k ,v
|
del q, k ,v
|
||||||
qkv_list.clear()
|
qkv_list.clear()
|
||||||
|
|
||||||
@ -159,36 +159,72 @@ def pay_attention(
|
|||||||
deterministic=False,
|
deterministic=False,
|
||||||
version=None,
|
version=None,
|
||||||
force_attention= None,
|
force_attention= None,
|
||||||
cross_attn= False
|
cross_attn= False,
|
||||||
|
k_lens = None
|
||||||
):
|
):
|
||||||
|
|
||||||
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
|
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
|
||||||
q,k,v = qkv_list
|
q,k,v = qkv_list
|
||||||
qkv_list.clear()
|
qkv_list.clear()
|
||||||
|
|
||||||
|
|
||||||
# params
|
# params
|
||||||
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
||||||
assert b==1
|
|
||||||
q = q.squeeze(0)
|
|
||||||
k = k.squeeze(0)
|
|
||||||
v = v.squeeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
q = q.to(v.dtype)
|
q = q.to(v.dtype)
|
||||||
k = k.to(v.dtype)
|
k = k.to(v.dtype)
|
||||||
|
if b > 0 and k_lens != None and attn in ("sage2", "sdpa"):
|
||||||
# if q_scale is not None:
|
# Poor's man var len attention
|
||||||
# q = q * q_scale
|
chunk_sizes = []
|
||||||
|
k_sizes = []
|
||||||
|
current_size = k_lens[0]
|
||||||
|
current_count= 1
|
||||||
|
for k_len in k_lens[1:]:
|
||||||
|
if k_len == current_size:
|
||||||
|
current_count += 1
|
||||||
|
else:
|
||||||
|
chunk_sizes.append(current_count)
|
||||||
|
k_sizes.append(current_size)
|
||||||
|
current_count = 1
|
||||||
|
current_size = k_len
|
||||||
|
chunk_sizes.append(current_count)
|
||||||
|
k_sizes.append(k_len)
|
||||||
|
if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]:
|
||||||
|
q_chunks =torch.split(q, chunk_sizes)
|
||||||
|
k_chunks =torch.split(k, chunk_sizes)
|
||||||
|
v_chunks =torch.split(v, chunk_sizes)
|
||||||
|
q, k, v = None, None, None
|
||||||
|
k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)]
|
||||||
|
v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)]
|
||||||
|
o = []
|
||||||
|
for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks):
|
||||||
|
qkv_list = [sub_q, sub_k, sub_v]
|
||||||
|
sub_q, sub_k, sub_v = None, None, None
|
||||||
|
o.append( pay_attention(qkv_list) )
|
||||||
|
q_chunks, k_chunks, v_chunks = None, None, None
|
||||||
|
o = torch.cat(o, dim = 0)
|
||||||
|
return o
|
||||||
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Flash attention 3 is not available, use flash attention 2 instead.'
|
'Flash attention 3 is not available, use flash attention 2 instead.'
|
||||||
)
|
)
|
||||||
|
|
||||||
if attn=="sage" or attn=="flash":
|
if attn=="sage" or attn=="flash":
|
||||||
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
|
if b != 1 :
|
||||||
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
|
if k_lens == None:
|
||||||
|
k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
|
||||||
|
k = torch.cat([u[:v] for u, v in zip(k, k_lens)])
|
||||||
|
v = torch.cat([u[:v] for u, v in zip(v, k_lens)])
|
||||||
|
q = q.reshape(-1, *q.shape[-2:])
|
||||||
|
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
|
||||||
|
cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
|
||||||
|
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
|
||||||
|
cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
|
||||||
|
q = q.squeeze(0)
|
||||||
|
k = k.squeeze(0)
|
||||||
|
v = v.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
# apply attention
|
# apply attention
|
||||||
if attn=="sage":
|
if attn=="sage":
|
||||||
@ -207,7 +243,7 @@ def pay_attention(
|
|||||||
qkv_list = [q,k,v]
|
qkv_list = [q,k,v]
|
||||||
del q,k,v
|
del q,k,v
|
||||||
|
|
||||||
x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
|
x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
|
||||||
# else:
|
# else:
|
||||||
# layer = offload.shared_state["layer"]
|
# layer = offload.shared_state["layer"]
|
||||||
# embed_sizes = offload.shared_state["embed_sizes"]
|
# embed_sizes = offload.shared_state["embed_sizes"]
|
||||||
@ -267,8 +303,8 @@ def pay_attention(
|
|||||||
|
|
||||||
elif attn=="sdpa":
|
elif attn=="sdpa":
|
||||||
qkv_list = [q, k, v]
|
qkv_list = [q, k, v]
|
||||||
del q, k , v
|
del q ,k ,v
|
||||||
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
|
x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0)
|
||||||
elif attn=="flash" and version == 3:
|
elif attn=="flash" and version == 3:
|
||||||
# Note: dropout_p, window_size are not supported in FA3 now.
|
# Note: dropout_p, window_size are not supported in FA3 now.
|
||||||
x = flash_attn_interface.flash_attn_varlen_func(
|
x = flash_attn_interface.flash_attn_varlen_func(
|
||||||
@ -302,59 +338,11 @@ def pay_attention(
|
|||||||
# output
|
# output
|
||||||
|
|
||||||
elif attn=="xformers":
|
elif attn=="xformers":
|
||||||
x = memory_efficient_attention(
|
from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
|
||||||
q.unsqueeze(0),
|
if b != 1 and k_lens != None:
|
||||||
k.unsqueeze(0),
|
attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) )
|
||||||
v.unsqueeze(0),
|
x = memory_efficient_attention(q, k, v, attn_bias= attn_mask )
|
||||||
) #.unsqueeze(0)
|
else:
|
||||||
|
x = memory_efficient_attention(q, k, v )
|
||||||
|
|
||||||
return x.type(out_dtype)
|
return x.type(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
q_lens=None,
|
|
||||||
k_lens=None,
|
|
||||||
dropout_p=0.,
|
|
||||||
softmax_scale=None,
|
|
||||||
q_scale=None,
|
|
||||||
causal=False,
|
|
||||||
window_size=(-1, -1),
|
|
||||||
deterministic=False,
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
fa_version=None,
|
|
||||||
):
|
|
||||||
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
|
||||||
return pay_attention(
|
|
||||||
q=q,
|
|
||||||
k=k,
|
|
||||||
v=v,
|
|
||||||
q_lens=q_lens,
|
|
||||||
k_lens=k_lens,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
q_scale=q_scale,
|
|
||||||
causal=causal,
|
|
||||||
window_size=window_size,
|
|
||||||
deterministic=deterministic,
|
|
||||||
dtype=dtype,
|
|
||||||
version=fa_version,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if q_lens is not None or k_lens is not None:
|
|
||||||
warnings.warn(
|
|
||||||
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
|
|
||||||
)
|
|
||||||
attn_mask = None
|
|
||||||
|
|
||||||
q = q.transpose(1, 2).to(dtype)
|
|
||||||
k = k.transpose(1, 2).to(dtype)
|
|
||||||
v = v.transpose(1, 2).to(dtype)
|
|
||||||
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
|
|
||||||
|
|
||||||
out = out.transpose(1, 2).contiguous()
|
|
||||||
return out
|
|
||||||
|
|||||||
@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module):
|
|||||||
del q,k
|
del q,k
|
||||||
|
|
||||||
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
||||||
qkv_list = [q,k,v]
|
|
||||||
del q,k,v
|
|
||||||
if block_mask == None:
|
if block_mask == None:
|
||||||
|
qkv_list = [q,k,v]
|
||||||
|
del q,k,v
|
||||||
x = pay_attention(
|
x = pay_attention(
|
||||||
qkv_list,
|
qkv_list,
|
||||||
window_size=self.window_size)
|
window_size=self.window_size)
|
||||||
@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
.transpose(1, 2)
|
.transpose(1, 2)
|
||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
|
del q,k,v
|
||||||
|
|
||||||
# if not self._flag_ar_attention:
|
# if not self._flag_ar_attention:
|
||||||
# q = rope_apply(q, grid_sizes, freqs)
|
# q = rope_apply(q, grid_sizes, freqs)
|
||||||
@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
|
|
||||||
class WanT2VCrossAttention(WanSelfAttention):
|
class WanT2VCrossAttention(WanSelfAttention):
|
||||||
|
|
||||||
def forward(self, xlist, context):
|
def forward(self, xlist, context, grid_sizes, *args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
|||||||
v = self.v(context).view(b, -1, n, d)
|
v = self.v(context).view(b, -1, n, d)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
|
v = v.contiguous().clone()
|
||||||
qvl_list=[q, k, v]
|
qvl_list=[q, k, v]
|
||||||
del q, k, v
|
del q, k, v
|
||||||
x = pay_attention(qvl_list, cross_attn= True)
|
x = pay_attention(qvl_list, cross_attn= True)
|
||||||
@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||||
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, xlist, context):
|
def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
del x
|
del x
|
||||||
self.norm_q(q)
|
self.norm_q(q)
|
||||||
q= q.view(b, -1, n, d)
|
q= q.view(b, -1, n, d)
|
||||||
|
if audio_scale != None:
|
||||||
|
audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
|
||||||
k = self.k(context)
|
k = self.k(context)
|
||||||
self.norm_k(k)
|
self.norm_k(k)
|
||||||
k = k.view(b, -1, n, d)
|
k = k.view(b, -1, n, d)
|
||||||
@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
img_x = img_x.flatten(2)
|
img_x = img_x.flatten(2)
|
||||||
x += img_x
|
x += img_x
|
||||||
del img_x
|
del img_x
|
||||||
|
if audio_scale != None:
|
||||||
|
x.add_(audio_x, alpha= audio_scale)
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module):
|
|||||||
hints= None,
|
hints= None,
|
||||||
context_scale=1.0,
|
context_scale=1.0,
|
||||||
cam_emb= None,
|
cam_emb= None,
|
||||||
block_mask = None
|
block_mask = None,
|
||||||
|
audio_proj= None,
|
||||||
|
audio_context_lens= None,
|
||||||
|
audio_scale=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
if cam_emb != None:
|
if cam_emb != None:
|
||||||
cam_emb = self.cam_encoder(cam_emb)
|
cam_emb = self.cam_encoder(cam_emb)
|
||||||
cam_emb = cam_emb.repeat(1, 2, 1)
|
cam_emb = cam_emb.repeat(1, 2, 1)
|
||||||
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
|
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1)
|
||||||
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
|
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
|
||||||
x_mod += cam_emb
|
x_mod += cam_emb
|
||||||
|
|
||||||
@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
y = y.to(attention_dtype)
|
y = y.to(attention_dtype)
|
||||||
ylist= [y]
|
ylist= [y]
|
||||||
del y
|
del y
|
||||||
x += self.cross_attn(ylist, context).to(dtype)
|
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
||||||
|
|
||||||
y = self.norm2(x)
|
y = self.norm2(x)
|
||||||
|
|
||||||
@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
recammaster = False,
|
recammaster = False,
|
||||||
inject_sample_info = False,
|
inject_sample_info = False,
|
||||||
|
fantasytalking_dim = 0,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initialize the diffusion model backbone.
|
Initialize the diffusion model backbone.
|
||||||
@ -742,43 +752,48 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
block.projector.weight = nn.Parameter(torch.eye(dim))
|
block.projector.weight = nn.Parameter(torch.eye(dim))
|
||||||
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
|
if fantasytalking_dim > 0:
|
||||||
|
from fantasytalking.model import WanCrossAttentionProcessor
|
||||||
|
for block in self.blocks:
|
||||||
|
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
|
||||||
|
layer_list = [self.head, self.head.head, self.patch_embedding]
|
||||||
|
target_dype= dtype
|
||||||
|
|
||||||
|
layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2],
|
||||||
|
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
layer_list2 += [block.norm3]
|
||||||
|
|
||||||
def lock_layers_dtypes(self, dtype = torch.float32, force = False):
|
|
||||||
count = 0
|
|
||||||
layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
|
|
||||||
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
|
|
||||||
if hasattr(self, "fps_embedding"):
|
if hasattr(self, "fps_embedding"):
|
||||||
layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
|
layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
|
||||||
|
|
||||||
if hasattr(self, "vace_patch_embedding"):
|
if hasattr(self, "vace_patch_embedding"):
|
||||||
layer_list += [self.vace_patch_embedding]
|
layer_list2 += [self.vace_patch_embedding]
|
||||||
layer_list += [self.vace_blocks[0].before_proj]
|
layer_list2 += [self.vace_blocks[0].before_proj]
|
||||||
for block in self.vace_blocks:
|
for block in self.vace_blocks:
|
||||||
layer_list += [block.after_proj, block.norm3]
|
layer_list2 += [block.after_proj, block.norm3]
|
||||||
|
|
||||||
|
target_dype2 = hybrid_dtype if hybrid_dtype != None else dtype
|
||||||
|
|
||||||
# cam master
|
# cam master
|
||||||
if hasattr(self.blocks[0], "projector"):
|
if hasattr(self.blocks[0], "projector"):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
layer_list += [block.projector]
|
layer_list2 += [block.projector]
|
||||||
|
|
||||||
for block in self.blocks:
|
for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dype, target_dype2]):
|
||||||
layer_list += [block.norm3]
|
for layer in current_layer_list:
|
||||||
for layer in layer_list:
|
layer._lock_dtype = dtype
|
||||||
if hasattr(layer, "weight"):
|
|
||||||
if layer.weight.dtype == dtype :
|
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
|
||||||
count += 1
|
layer.weight.data = layer.weight.data.to(current_dtype)
|
||||||
elif force:
|
|
||||||
if hasattr(layer, "weight"):
|
|
||||||
layer.weight.data = layer.weight.data.to(dtype)
|
|
||||||
if hasattr(layer, "bias"):
|
if hasattr(layer, "bias"):
|
||||||
layer.bias.data = layer.bias.data.to(dtype)
|
layer.bias.data = layer.bias.data.to(current_dtype)
|
||||||
count += 1
|
|
||||||
|
|
||||||
layer._lock_dtype = dtype
|
self._lock_dtype = dtype
|
||||||
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
self._lock_dtype = dtype
|
|
||||||
|
|
||||||
|
|
||||||
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
||||||
@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
t = torch.stack([t])
|
t = torch.stack([t])
|
||||||
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
|
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
|
||||||
e_list.append(time_emb)
|
e_list.append(time_emb)
|
||||||
|
best_deltas = None
|
||||||
best_threshold = 0.01
|
best_threshold = 0.01
|
||||||
best_diff = 1000
|
best_diff = 1000
|
||||||
best_signed_diff = 1000
|
best_signed_diff = 1000
|
||||||
@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
accumulated_rel_l1_distance =0
|
accumulated_rel_l1_distance =0
|
||||||
nb_steps = 0
|
nb_steps = 0
|
||||||
diff = 1000
|
diff = 1000
|
||||||
|
deltas = []
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
skip = False
|
skip = False
|
||||||
if not (i<=start_step or i== len(timesteps)):
|
if not (i<=start_step or i== len(timesteps)-1):
|
||||||
accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
|
delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
|
||||||
|
# deltas.append(delta)
|
||||||
|
accumulated_rel_l1_distance += delta
|
||||||
if accumulated_rel_l1_distance < threshold:
|
if accumulated_rel_l1_distance < threshold:
|
||||||
skip = True
|
skip = True
|
||||||
|
# deltas.append("SKIP")
|
||||||
else:
|
else:
|
||||||
accumulated_rel_l1_distance = 0
|
accumulated_rel_l1_distance = 0
|
||||||
if not skip:
|
if not skip:
|
||||||
@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
diff = abs(signed_diff)
|
diff = abs(signed_diff)
|
||||||
if diff < best_diff:
|
if diff < best_diff:
|
||||||
best_threshold = threshold
|
best_threshold = threshold
|
||||||
|
best_deltas = deltas
|
||||||
best_diff = diff
|
best_diff = diff
|
||||||
best_signed_diff = signed_diff
|
best_signed_diff = signed_diff
|
||||||
elif diff > best_diff:
|
elif diff > best_diff:
|
||||||
@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
threshold += 0.01
|
threshold += 0.01
|
||||||
self.rel_l1_thresh = best_threshold
|
self.rel_l1_thresh = best_threshold
|
||||||
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
||||||
|
# print(f"deltas:{best_deltas}")
|
||||||
return best_threshold
|
return best_threshold
|
||||||
|
|
||||||
|
|
||||||
@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
freqs = None,
|
freqs = None,
|
||||||
pipeline = None,
|
pipeline = None,
|
||||||
current_step = 0,
|
current_step = 0,
|
||||||
is_uncond=False,
|
x_id= 0,
|
||||||
max_steps = 0,
|
max_steps = 0,
|
||||||
slg_layers=None,
|
slg_layers=None,
|
||||||
callback = None,
|
callback = None,
|
||||||
@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
fps = None,
|
fps = None,
|
||||||
causal_block_size = 1,
|
causal_block_size = 1,
|
||||||
causal_attention = False,
|
causal_attention = False,
|
||||||
x_neg = None
|
audio_proj=None,
|
||||||
|
audio_context_lens=None,
|
||||||
|
audio_scale=None,
|
||||||
|
|
||||||
):
|
):
|
||||||
# dtype = self.blocks[0].self_attn.q.weight.dtype
|
# patch_dtype = self.patch_embedding.weight.dtype
|
||||||
dtype = self.patch_embedding.weight.dtype
|
modulation_dtype = self.time_projection[1].weight.dtype
|
||||||
|
|
||||||
if self.model_type == 'i2v':
|
if self.model_type == 'i2v':
|
||||||
assert clip_fea is not None and y is not None
|
assert clip_fea is not None and y is not None
|
||||||
@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
if torch.is_tensor(freqs) and freqs.device != device:
|
if torch.is_tensor(freqs) and freqs.device != device:
|
||||||
freqs = freqs.to(device)
|
freqs = freqs.to(device)
|
||||||
|
|
||||||
if y is not None:
|
|
||||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
|
||||||
|
|
||||||
# embeddings
|
x_list = x
|
||||||
x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
|
joint_pass = len(x_list) > 1
|
||||||
if x_neg !=None:
|
is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ]
|
||||||
x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]
|
last_x_idx = 0
|
||||||
|
for i, (is_source, x) in enumerate(zip(is_source_x, x_list)):
|
||||||
|
if is_source:
|
||||||
|
x_list[i] = x_list[0].clone()
|
||||||
|
last_x_idx = i
|
||||||
|
else:
|
||||||
|
# image source
|
||||||
|
if y is not None:
|
||||||
|
x = torch.cat([x, y], dim=0)
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x_list[i] = x
|
||||||
|
x, y = None, None
|
||||||
|
|
||||||
grid_sizes = [ list(u.shape[2:]) for u in x]
|
|
||||||
embed_sizes = grid_sizes[0]
|
block_mask = None
|
||||||
if causal_attention : #causal_block_size > 0:
|
if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
|
||||||
frame_num = embed_sizes[0]
|
frame_num = grid_sizes[0]
|
||||||
height = embed_sizes[1]
|
height = grid_sizes[1]
|
||||||
width = embed_sizes[2]
|
width = grid_sizes[2]
|
||||||
block_num = frame_num // causal_block_size
|
block_num = frame_num // causal_block_size
|
||||||
range_tensor = torch.arange(block_num).view(-1, 1)
|
range_tensor = torch.arange(block_num).view(-1, 1)
|
||||||
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
|
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
|
||||||
@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
||||||
del causal_mask
|
del causal_mask
|
||||||
|
|
||||||
offload.shared_state["embed_sizes"] = embed_sizes
|
offload.shared_state["embed_sizes"] = grid_sizes
|
||||||
offload.shared_state["step_no"] = current_step
|
offload.shared_state["step_no"] = current_step
|
||||||
offload.shared_state["max_steps"] = max_steps
|
offload.shared_state["max_steps"] = max_steps
|
||||||
|
|
||||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
_flag_df = t.dim() == 2
|
||||||
x = x[0]
|
|
||||||
if x_neg !=None:
|
|
||||||
x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
|
|
||||||
x_neg = x_neg[0]
|
|
||||||
|
|
||||||
if t.dim() == 2:
|
|
||||||
b, f = t.shape
|
|
||||||
_flag_df = True
|
|
||||||
else:
|
|
||||||
_flag_df = False
|
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
|
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype)
|
||||||
) # b, dim
|
) # b, dim
|
||||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||||
|
|
||||||
if self.inject_sample_info:
|
if self.inject_sample_info:
|
||||||
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
||||||
|
|
||||||
fps_emb = self.fps_embedding(fps).to(dtype) # float()
|
fps_emb = self.fps_embedding(fps).to(e.dtype)
|
||||||
if _flag_df:
|
if _flag_df:
|
||||||
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
||||||
else:
|
else:
|
||||||
@ -914,29 +941,27 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
||||||
|
|
||||||
joint_pass = len(context) > 0
|
|
||||||
x_list = [x]
|
|
||||||
if joint_pass:
|
|
||||||
if x_neg == None:
|
|
||||||
x_list += [x.clone() for i in range(len(context) - 1) ]
|
|
||||||
else:
|
|
||||||
x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
|
|
||||||
is_uncond = False
|
|
||||||
del x
|
|
||||||
context_list = context
|
context_list = context
|
||||||
|
if audio_scale != None:
|
||||||
|
audio_scale_list = audio_scale
|
||||||
|
else:
|
||||||
|
audio_scale_list = [None] * len(x_list)
|
||||||
|
|
||||||
# arguments
|
# arguments
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
grid_sizes=grid_sizes,
|
grid_sizes=grid_sizes,
|
||||||
freqs=freqs,
|
freqs=freqs,
|
||||||
cam_emb = cam_emb
|
cam_emb = cam_emb,
|
||||||
|
block_mask = block_mask,
|
||||||
|
audio_proj=audio_proj,
|
||||||
|
audio_context_lens=audio_context_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if vace_context == None:
|
if vace_context == None:
|
||||||
hints_list = [None ] *len(x_list)
|
hints_list = [None ] *len(x_list)
|
||||||
else:
|
else:
|
||||||
# embeddings
|
# Vace embeddings
|
||||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
c = c[0]
|
c = c[0]
|
||||||
@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
should_calc = True
|
should_calc = True
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
if is_uncond:
|
if x_id != 0:
|
||||||
should_calc = self.should_calc
|
should_calc = self.should_calc
|
||||||
else:
|
else:
|
||||||
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
|
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
|
||||||
@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.accumulated_rel_l1_distance = 0
|
self.accumulated_rel_l1_distance = 0
|
||||||
else:
|
else:
|
||||||
rescale_func = np.poly1d(self.coefficients)
|
rescale_func = np.poly1d(self.coefficients)
|
||||||
self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
|
delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
|
||||||
|
self.accumulated_rel_l1_distance += delta
|
||||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||||
should_calc = False
|
should_calc = False
|
||||||
self.teacache_skipped_steps += 1
|
self.teacache_skipped_steps += 1
|
||||||
# print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
|
# print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" )
|
||||||
else:
|
else:
|
||||||
should_calc = True
|
should_calc = True
|
||||||
self.accumulated_rel_l1_distance = 0
|
self.accumulated_rel_l1_distance = 0
|
||||||
@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.should_calc = should_calc
|
self.should_calc = should_calc
|
||||||
|
|
||||||
if not should_calc:
|
if not should_calc:
|
||||||
for i, x in enumerate(x_list):
|
if joint_pass:
|
||||||
x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
|
for i, x in enumerate(x_list):
|
||||||
|
x += self.previous_residual[i]
|
||||||
|
else:
|
||||||
|
x = x_list[0]
|
||||||
|
x += self.previous_residual[x_id]
|
||||||
|
x = None
|
||||||
else:
|
else:
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
if joint_pass or is_uncond:
|
if joint_pass:
|
||||||
self.previous_residual_uncond = None
|
self.previous_residual = [ None ] * len(self.previous_residual)
|
||||||
if joint_pass or not is_uncond:
|
else:
|
||||||
self.previous_residual_cond = None
|
self.previous_residual[x_id] = None
|
||||||
ori_hidden_states = x_list[0].clone()
|
ori_hidden_states = [ None ] * len(x_list)
|
||||||
|
ori_hidden_states[0] = x_list[0].clone()
|
||||||
|
for i in range(1, len(x_list)):
|
||||||
|
ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone()
|
||||||
|
|
||||||
for block_idx, block in enumerate(self.blocks):
|
for block_idx, block in enumerate(self.blocks):
|
||||||
offload.shared_state["layer"] = block_idx
|
offload.shared_state["layer"] = block_idx
|
||||||
@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
if pipeline._interrupt:
|
if pipeline._interrupt:
|
||||||
return [None] * len(x_list)
|
return [None] * len(x_list)
|
||||||
|
|
||||||
if slg_layers is not None and block_idx in slg_layers:
|
if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
|
||||||
if is_uncond and not joint_pass:
|
if not joint_pass:
|
||||||
continue
|
continue
|
||||||
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
|
for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)):
|
||||||
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
|
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
|
||||||
del x
|
del x
|
||||||
del context, hints
|
del context, hints
|
||||||
|
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
if joint_pass:
|
if joint_pass:
|
||||||
self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
|
for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
|
||||||
self.previous_residual_uncond = ori_hidden_states
|
if i == 0 or is_source and i != last_x_idx :
|
||||||
torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
|
self.previous_residual[i] = torch.sub(x, ori)
|
||||||
|
else:
|
||||||
|
self.previous_residual[i] = ori
|
||||||
|
torch.sub(x, ori, out=self.previous_residual[i])
|
||||||
|
ori_hidden_states[i] = None
|
||||||
|
x , ori = None, None
|
||||||
else:
|
else:
|
||||||
residual = ori_hidden_states # just to have a readable code
|
residual = ori_hidden_states[0] # just to have a readable code
|
||||||
torch.sub(x_list[0], ori_hidden_states, out=residual)
|
torch.sub(x_list[0], ori_hidden_states[0], out=residual)
|
||||||
if i==1 or is_uncond:
|
self.previous_residual[x_id] = residual
|
||||||
self.previous_residual_uncond = residual
|
|
||||||
else:
|
|
||||||
self.previous_residual_cond = residual
|
|
||||||
residual, ori_hidden_states = None, None
|
residual, ori_hidden_states = None, None
|
||||||
|
|
||||||
for i, x in enumerate(x_list):
|
for i, x in enumerate(x_list):
|
||||||
@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
c = self.out_dim
|
c = self.out_dim
|
||||||
out = []
|
out = []
|
||||||
for u, v in zip(x, grid_sizes):
|
for u in x:
|
||||||
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
|
||||||
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
||||||
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||||
out.append(u)
|
out.append(u)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@ -140,7 +140,7 @@ def sageattn(
|
|||||||
elif arch == "sm90":
|
elif arch == "sm90":
|
||||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
|
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
|
||||||
elif arch == "sm120":
|
elif arch == "sm120":
|
||||||
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
|
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported CUDA architecture: {arch}")
|
raise ValueError(f"Unsupported CUDA architecture: {arch}")
|
||||||
|
|
||||||
|
|||||||
@ -78,15 +78,16 @@ class WanT2V:
|
|||||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {model_filename}")
|
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
# model_filename
|
# model_filename
|
||||||
|
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
|
||||||
# offload.load_model_data(self.model, "e:/vace.safetensors")
|
# offload.load_model_data(self.model, "e:/vace.safetensors")
|
||||||
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
||||||
# self.model.to(torch.bfloat16)
|
# self.model.to(torch.bfloat16)
|
||||||
# self.model.cpu()
|
# self.model.cpu()
|
||||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
|
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
||||||
offload.change_dtype(self.model, dtype, True)
|
offload.change_dtype(self.model, dtype, True)
|
||||||
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
|
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
|
||||||
# offload.save_model(self.model, "phantom_1.3B.safetensors")
|
# offload.save_model(self.model, "phantom_1.3B.safetensors")
|
||||||
@ -95,7 +96,7 @@ class WanT2V:
|
|||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
|
|
||||||
if "Vace" in model_filename:
|
if "Vace" in model_filename[-1]:
|
||||||
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||||
min_area=480*832,
|
min_area=480*832,
|
||||||
max_area=480*832,
|
max_area=480*832,
|
||||||
@ -107,7 +108,7 @@ class WanT2V:
|
|||||||
|
|
||||||
self.adapt_vace_model()
|
self.adapt_vace_model()
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0):
|
||||||
if ref_images is None:
|
if ref_images is None:
|
||||||
ref_images = [None] * len(frames)
|
ref_images = [None] * len(frames)
|
||||||
else:
|
else:
|
||||||
@ -119,6 +120,11 @@ class WanT2V:
|
|||||||
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
||||||
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||||
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
||||||
|
# inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive]
|
||||||
|
# if overlapped_latents > 0:
|
||||||
|
# for t in inactive:
|
||||||
|
# t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor
|
||||||
|
|
||||||
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
||||||
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
||||||
|
|
||||||
@ -288,7 +294,10 @@ class WanT2V:
|
|||||||
slg_end = 1.0,
|
slg_end = 1.0,
|
||||||
cfg_star_switch = True,
|
cfg_star_switch = True,
|
||||||
cfg_zero_step = 5,
|
cfg_zero_step = 5,
|
||||||
):
|
overlapped_latents = 0,
|
||||||
|
overlap_noise = 0,
|
||||||
|
vace = False
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Generates video frames from text prompt using diffusion process.
|
Generates video frames from text prompt using diffusion process.
|
||||||
|
|
||||||
@ -343,20 +352,20 @@ class WanT2V:
|
|||||||
size = (source_video.shape[2], source_video.shape[1])
|
size = (source_video.shape[2], source_video.shape[1])
|
||||||
source_video = source_video.to(dtype=self.dtype , device=self.device)
|
source_video = source_video.to(dtype=self.dtype , device=self.device)
|
||||||
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
||||||
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
|
source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device)
|
||||||
del source_video
|
del source_video
|
||||||
# Process target camera (recammaster)
|
# Process target camera (recammaster)
|
||||||
from wan.utils.cammmaster_tools import get_camera_embedding
|
from wan.utils.cammmaster_tools import get_camera_embedding
|
||||||
cam_emb = get_camera_embedding(target_camera)
|
cam_emb = get_camera_embedding(target_camera)
|
||||||
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
||||||
|
|
||||||
if input_frames != None:
|
if vace :
|
||||||
# vace context encode
|
# vace context encode
|
||||||
input_frames = [u.to(self.device) for u in input_frames]
|
input_frames = [u.to(self.device) for u in input_frames]
|
||||||
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
||||||
input_masks = [u.to(self.device) for u in input_masks]
|
input_masks = [u.to(self.device) for u in input_masks]
|
||||||
|
|
||||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
|
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise )
|
||||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||||
z = self.vace_latent(z0, m0)
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
@ -365,10 +374,10 @@ class WanT2V:
|
|||||||
else:
|
else:
|
||||||
if input_ref_images != None: # Phantom Ref images
|
if input_ref_images != None: # Phantom Ref images
|
||||||
phantom = True
|
phantom = True
|
||||||
input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
|
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
||||||
input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
|
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
||||||
F = frame_num
|
F = frame_num
|
||||||
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
|
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
|
||||||
size[1] // self.vae_stride[1],
|
size[1] // self.vae_stride[1],
|
||||||
size[0] // self.vae_stride[2])
|
size[0] // self.vae_stride[2])
|
||||||
|
|
||||||
@ -405,37 +414,48 @@ class WanT2V:
|
|||||||
raise NotImplementedError("Unsupported solver.")
|
raise NotImplementedError("Unsupported solver.")
|
||||||
|
|
||||||
# sample videos
|
# sample videos
|
||||||
latents = noise
|
latents = noise[0]
|
||||||
del noise
|
del noise
|
||||||
batch_size =len(latents)
|
batch_size = 1
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
shape = list(latents[0].shape[1:])
|
shape = list(latents.shape[1:])
|
||||||
shape[0] *= 2
|
shape[0] *= 2
|
||||||
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
||||||
else:
|
else:
|
||||||
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
|
freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
||||||
|
|
||||||
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
|
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
|
||||||
|
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
kwargs.update({'cam_emb': cam_emb})
|
kwargs.update({'cam_emb': cam_emb})
|
||||||
|
|
||||||
if input_frames != None:
|
if vace:
|
||||||
|
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
|
||||||
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
|
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
|
||||||
|
if overlapped_latents > 0:
|
||||||
|
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
|
||||||
|
|
||||||
|
|
||||||
if self.model.enable_teacache:
|
if self.model.enable_teacache:
|
||||||
|
x_count = 3 if phantom else 2
|
||||||
|
self.model.previous_residual = [None] * x_count
|
||||||
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, None, True)
|
callback(-1, None, True)
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
|
if vace and overlapped_latents > 0 :
|
||||||
|
# noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
|
||||||
|
noise_factor = overlap_noise / 1000 # * (999-t) / 999
|
||||||
|
# noise_factor = overlap_noise / 1000 # * t / 999
|
||||||
|
for zz, zz_r in zip(z, z_reactive):
|
||||||
|
zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor
|
||||||
|
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
|
latent_model_input = torch.cat([latents, source_latents], dim=1)
|
||||||
else:
|
else:
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
slg_layers_local = None
|
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
||||||
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
|
|
||||||
slg_layers_local = slg_layers
|
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
offload.set_step_no_for_lora(self.model, i)
|
||||||
timestep = torch.stack(timestep)
|
timestep = torch.stack(timestep)
|
||||||
@ -444,38 +464,38 @@ class WanT2V:
|
|||||||
if joint_pass:
|
if joint_pass:
|
||||||
if phantom:
|
if phantom:
|
||||||
pos_it, pos_i, neg = self.model(
|
pos_it, pos_i, neg = self.model(
|
||||||
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
|
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
|
||||||
x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
|
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
|
||||||
context = [context, context_null, context_null], **kwargs)
|
context = [context, context_null, context_null], **kwargs)
|
||||||
else:
|
else:
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
noise_pred_cond, noise_pred_uncond = self.model(
|
||||||
latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
|
[latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
if phantom:
|
if phantom:
|
||||||
pos_it = self.model(
|
pos_it = self.model(
|
||||||
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
|
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
pos_i = self.model(
|
pos_i = self.model(
|
||||||
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs
|
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
neg = self.model(
|
neg = self.model(
|
||||||
[torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
|
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
|
||||||
)[0]
|
)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model(
|
||||||
latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
|
[latent_model_input], x_id = 0, context = [context], **kwargs)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
noise_pred_uncond = self.model(
|
noise_pred_uncond = self.model(
|
||||||
latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0]
|
[latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -505,21 +525,21 @@ class WanT2V:
|
|||||||
temp_x0 = sample_scheduler.step(
|
temp_x0 = sample_scheduler.step(
|
||||||
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
||||||
t,
|
t,
|
||||||
latents[0].unsqueeze(0),
|
latents.unsqueeze(0),
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latents = [temp_x0.squeeze(0)]
|
latents = temp_x0.squeeze(0)
|
||||||
del temp_x0
|
del temp_x0
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, latents[0], False)
|
callback(i, latents, False)
|
||||||
|
|
||||||
x0 = latents
|
x0 = [latents]
|
||||||
|
|
||||||
if input_frames == None:
|
if input_frames == None:
|
||||||
if phantom:
|
if phantom:
|
||||||
# phantom post processing
|
# phantom post processing
|
||||||
x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
|
x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
|
||||||
videos = self.vae.decode(x0, VAE_tile_size)
|
videos = self.vae.decode(x0, VAE_tile_size)
|
||||||
else:
|
else:
|
||||||
# vace post processing
|
# vace post processing
|
||||||
|
|||||||
339
wgp.py
339
wgp.py
@ -40,7 +40,7 @@ global_queue_ref = []
|
|||||||
AUTOSAVE_FILENAME = "queue.zip"
|
AUTOSAVE_FILENAME = "queue.zip"
|
||||||
PROMPT_VARS_MAX = 10
|
PROMPT_VARS_MAX = 10
|
||||||
|
|
||||||
target_mmgp_version = "3.4.1"
|
target_mmgp_version = "3.4.2"
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
mmgp_version = version("mmgp")
|
mmgp_version = version("mmgp")
|
||||||
if mmgp_version != target_mmgp_version:
|
if mmgp_version != target_mmgp_version:
|
||||||
@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version:
|
|||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
current_task_id = None
|
current_task_id = None
|
||||||
task_id = 0
|
task_id = 0
|
||||||
# progress_tracker = {}
|
|
||||||
# tracker_lock = threading.Lock()
|
|
||||||
|
|
||||||
# def download_ffmpeg():
|
def download_ffmpeg():
|
||||||
# if os.name != 'nt': return
|
if os.name != 'nt': return
|
||||||
# exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
|
exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
|
||||||
# if all(os.path.exists(e) for e in exes): return
|
if all(os.path.exists(e) for e in exes): return
|
||||||
# api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
|
api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
|
||||||
# r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
|
r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
|
||||||
# assets = r.json().get('assets', [])
|
assets = r.json().get('assets', [])
|
||||||
# zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
|
zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
|
||||||
# if not zip_asset: return
|
if not zip_asset: return
|
||||||
# zip_url = zip_asset['browser_download_url']
|
zip_url = zip_asset['browser_download_url']
|
||||||
# zip_name = zip_asset['name']
|
zip_name = zip_asset['name']
|
||||||
# with requests.get(zip_url, stream=True) as resp:
|
with requests.get(zip_url, stream=True) as resp:
|
||||||
# total = int(resp.headers.get('Content-Length', 0))
|
total = int(resp.headers.get('Content-Length', 0))
|
||||||
# with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
|
with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
|
||||||
# for chunk in resp.iter_content(chunk_size=8192):
|
for chunk in resp.iter_content(chunk_size=8192):
|
||||||
# f.write(chunk)
|
f.write(chunk)
|
||||||
# pbar.update(len(chunk))
|
pbar.update(len(chunk))
|
||||||
# with zipfile.ZipFile(zip_name) as z:
|
with zipfile.ZipFile(zip_name) as z:
|
||||||
# for f in z.namelist():
|
for f in z.namelist():
|
||||||
# if f.endswith(tuple(exes)) and '/bin/' in f:
|
if f.endswith(tuple(exes)) and '/bin/' in f:
|
||||||
# z.extract(f)
|
z.extract(f)
|
||||||
# os.rename(f, os.path.basename(f))
|
os.rename(f, os.path.basename(f))
|
||||||
# os.remove(zip_name)
|
os.remove(zip_name)
|
||||||
|
|
||||||
def format_time(seconds):
|
def format_time(seconds):
|
||||||
if seconds < 60:
|
if seconds < 60:
|
||||||
@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|||||||
resolution = inputs["resolution"]
|
resolution = inputs["resolution"]
|
||||||
width, height = resolution.split("x")
|
width, height = resolution.split("x")
|
||||||
width, height = int(width), int(height)
|
width, height = int(width), int(height)
|
||||||
if test_class_i2v(model_filename):
|
# if test_class_i2v(model_filename):
|
||||||
if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
|
# if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
|
||||||
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
# gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
||||||
return
|
# return
|
||||||
resolution = str(width) + "*" + str(height)
|
# resolution = str(width) + "*" + str(height)
|
||||||
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
# if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
||||||
gr.Info(f"Resolution {resolution} not supported by image 2 video")
|
# gr.Info(f"Resolution {resolution} not supported by image 2 video")
|
||||||
return
|
# return
|
||||||
|
|
||||||
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
|
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
|
||||||
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
||||||
@ -533,7 +531,7 @@ def save_queue_action(state):
|
|||||||
task_id_s = task.get('id', f"task_{task_index}")
|
task_id_s = task.get('id', f"task_{task_index}")
|
||||||
|
|
||||||
image_keys = ["image_start", "image_end", "image_refs"]
|
image_keys = ["image_start", "image_end", "image_refs"]
|
||||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
images_pil = params_copy.get(key)
|
images_pil = params_copy.get(key)
|
||||||
@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
|
|||||||
params['state'] = state
|
params['state'] = state
|
||||||
|
|
||||||
image_keys = ["image_start", "image_end", "image_refs"]
|
image_keys = ["image_start", "image_end", "image_refs"]
|
||||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
|
||||||
|
|
||||||
loaded_pil_images = {}
|
loaded_pil_images = {}
|
||||||
loaded_video_paths = {}
|
loaded_video_paths = {}
|
||||||
@ -925,7 +923,7 @@ def autosave_queue():
|
|||||||
task_id_s = task.get('id', f"task_{task_index}")
|
task_id_s = task.get('id', f"task_{task_index}")
|
||||||
|
|
||||||
image_keys = ["image_start", "image_end", "image_refs"]
|
image_keys = ["image_start", "image_end", "image_refs"]
|
||||||
video_keys = ["video_guide", "video_mask", "video_source"]
|
video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
images_pil = params_copy.get(key)
|
images_pil = params_copy.get(key)
|
||||||
@ -1418,32 +1416,35 @@ else:
|
|||||||
text = reader.read()
|
text = reader.read()
|
||||||
server_config = json.loads(text)
|
server_config = json.loads(text)
|
||||||
|
|
||||||
# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
|
# Deprecated models
|
||||||
# if Path(src_path).is_file():
|
for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors",
|
||||||
# shutil.move(src_path, tgt_path) )
|
"sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors",
|
||||||
# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
|
"wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
|
||||||
# if Path(path).is_file():
|
]:
|
||||||
# os.remove(path)
|
if Path(os.path.join("ckpts" , path)).is_file():
|
||||||
|
os.remove( os.path.join("ckpts" , path))
|
||||||
|
|
||||||
path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
|
|
||||||
if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
|
|
||||||
os.remove(path)
|
|
||||||
|
|
||||||
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
|
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
|
||||||
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
|
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
|
||||||
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
|
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
|
||||||
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
|
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
|
||||||
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
|
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
|
||||||
"ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
|
"ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
|
||||||
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
|
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
|
||||||
|
"ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
|
||||||
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
||||||
|
def get_dependent_models(model_filename, quantization ):
|
||||||
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
|
if "fantasy" in model_filename:
|
||||||
|
return [get_model_filename("i2v_720p", quantization)]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"]
|
||||||
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
|
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
|
||||||
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
|
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
|
||||||
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
|
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
|
||||||
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
|
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
|
||||||
"phantom_1.3B" : "phantom_1.3B", }
|
"phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" }
|
||||||
|
|
||||||
|
|
||||||
def get_model_type(model_filename):
|
def get_model_type(model_filename):
|
||||||
@ -1453,7 +1454,7 @@ def get_model_type(model_filename):
|
|||||||
raise Exception("Unknown model:" + model_filename)
|
raise Exception("Unknown model:" + model_filename)
|
||||||
|
|
||||||
def test_class_i2v(model_filename):
|
def test_class_i2v(model_filename):
|
||||||
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
|
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename
|
||||||
|
|
||||||
def get_model_name(model_filename, description_container = [""]):
|
def get_model_name(model_filename, description_container = [""]):
|
||||||
if "Fun" in model_filename:
|
if "Fun" in model_filename:
|
||||||
@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]):
|
|||||||
model_name = "Wan2.1 Phantom"
|
model_name = "Wan2.1 Phantom"
|
||||||
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
||||||
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
|
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
|
||||||
|
elif "fantasy" in model_filename:
|
||||||
|
model_name = "Wan2.1 Fantasy Speaking 720p"
|
||||||
|
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
||||||
|
description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
|
||||||
else:
|
else:
|
||||||
model_name = "Wan2.1 text2video"
|
model_name = "Wan2.1 text2video"
|
||||||
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
||||||
@ -1536,6 +1541,7 @@ def get_default_settings(filename):
|
|||||||
"repeat_generation": 1,
|
"repeat_generation": 1,
|
||||||
"multi_images_gen_type": 0,
|
"multi_images_gen_type": 0,
|
||||||
"guidance_scale": 5.0,
|
"guidance_scale": 5.0,
|
||||||
|
"audio_guidance_scale": 5.0,
|
||||||
"flow_shift": get_default_flow(filename, i2v),
|
"flow_shift": get_default_flow(filename, i2v),
|
||||||
"negative_prompt": "",
|
"negative_prompt": "",
|
||||||
"activated_loras": [],
|
"activated_loras": [],
|
||||||
@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename):
|
|||||||
|
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
repoId = "DeepBeepMeep/Wan2.1"
|
repoId = "DeepBeepMeep/Wan2.1"
|
||||||
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
|
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ]
|
||||||
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
|
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
|
||||||
|
["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
|
||||||
targetRoot = "ckpts/"
|
targetRoot = "ckpts/"
|
||||||
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
||||||
if len(files)==0:
|
if len(files)==0:
|
||||||
@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
|
|||||||
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
|
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
|
||||||
|
|
||||||
|
|
||||||
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
|
def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
|
||||||
|
|
||||||
cfg = WAN_CONFIGS['t2v-14B']
|
cfg = WAN_CONFIGS['t2v-14B']
|
||||||
|
filename = model_filename[-1]
|
||||||
# cfg = WAN_CONFIGS['t2v-1.3B']
|
# cfg = WAN_CONFIGS['t2v-1.3B']
|
||||||
print(f"Loading '{model_filename}' model...")
|
print(f"Loading '{filename}' model...")
|
||||||
if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
|
if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
|
||||||
model_factory = wan.DTT2V
|
model_factory = wan.DTT2V
|
||||||
else:
|
else:
|
||||||
model_factory = wan.WanT2V
|
model_factory = wan.WanT2V
|
||||||
@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
|
|||||||
|
|
||||||
return wan_model, pipe
|
return wan_model, pipe
|
||||||
|
|
||||||
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
|
def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
|
||||||
|
|
||||||
print(f"Loading '{model_filename}' model...")
|
filename = model_filename[-1]
|
||||||
|
print(f"Loading '{filename}' model...")
|
||||||
|
|
||||||
cfg = WAN_CONFIGS['i2v-14B']
|
cfg = WAN_CONFIGS['i2v-14B']
|
||||||
wan_model = wan.WanI2V(
|
wan_model = wan.WanI2V(
|
||||||
@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t
|
|||||||
def load_models(model_filename):
|
def load_models(model_filename):
|
||||||
global transformer_filename
|
global transformer_filename
|
||||||
|
|
||||||
transformer_filename = model_filename
|
|
||||||
perc_reserved_mem_max = args.perc_reserved_mem_max
|
perc_reserved_mem_max = args.perc_reserved_mem_max
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
|
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
|
||||||
@ -1892,18 +1900,26 @@ def load_models(model_filename):
|
|||||||
default_dtype = torch.float16
|
default_dtype = torch.float16
|
||||||
else:
|
else:
|
||||||
default_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
default_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
||||||
if default_dtype == torch.float16 :
|
model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
|
||||||
if "quanto" in model_filename:
|
updated_model_filename = []
|
||||||
model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
|
for filename in model_filelist:
|
||||||
download_models(model_filename, text_encoder_filename)
|
if default_dtype == torch.float16 :
|
||||||
|
if "quanto_int8" in filename:
|
||||||
|
filename = filename.replace("quanto_int8", "quanto_fp16_int8")
|
||||||
|
elif "quanto_mbf16_int8":
|
||||||
|
filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
|
||||||
|
updated_model_filename.append(filename)
|
||||||
|
download_models(filename, text_encoder_filename)
|
||||||
|
model_filelist = updated_model_filename
|
||||||
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
|
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
|
||||||
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
|
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
|
||||||
if test_class_i2v(model_filename):
|
transformer_filename = None
|
||||||
res720P = "720p" in model_filename
|
new_transformer_filename = model_filelist[-1]
|
||||||
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
|
if test_class_i2v(new_transformer_filename):
|
||||||
|
wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
|
||||||
else:
|
else:
|
||||||
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
|
wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
|
||||||
wan_model._model_file_name = model_filename
|
wan_model._model_file_name = new_transformer_filename
|
||||||
kwargs = { "extraModelsToQuantize": None}
|
kwargs = { "extraModelsToQuantize": None}
|
||||||
if profile == 2 or profile == 4:
|
if profile == 2 or profile == 4:
|
||||||
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
|
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
|
||||||
@ -1914,7 +1930,7 @@ def load_models(model_filename):
|
|||||||
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
|
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
|
||||||
if len(args.gpu) > 0:
|
if len(args.gpu) > 0:
|
||||||
torch.set_default_device(args.gpu)
|
torch.set_default_device(args.gpu)
|
||||||
|
transformer_filename = new_transformer_filename
|
||||||
return wan_model, offloadobj, pipe["transformer"]
|
return wan_model, offloadobj, pipe["transformer"]
|
||||||
|
|
||||||
if not "P" in preload_model_policy:
|
if not "P" in preload_model_policy:
|
||||||
@ -2033,13 +2049,7 @@ def apply_changes( state,
|
|||||||
preload_model_policy = server_config["preload_model_policy"]
|
preload_model_policy = server_config["preload_model_policy"]
|
||||||
transformer_quantization = server_config["transformer_quantization"]
|
transformer_quantization = server_config["transformer_quantization"]
|
||||||
transformer_types = server_config["transformer_types"]
|
transformer_types = server_config["transformer_types"]
|
||||||
model_filename = state["model_filename"]
|
|
||||||
model_transformer_type = get_model_type(model_filename)
|
|
||||||
|
|
||||||
if not model_transformer_type in transformer_types:
|
|
||||||
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
|
|
||||||
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
|
|
||||||
state["model_filename"] = model_filename
|
|
||||||
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
|
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
|
||||||
model_choice = gr.Dropdown()
|
model_choice = gr.Dropdown()
|
||||||
else:
|
else:
|
||||||
@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
|
|||||||
frame_height, frame_width, _ = frames_list[0].shape
|
frame_height, frame_width, _ = frames_list[0].shape
|
||||||
|
|
||||||
if fit_canvas :
|
if fit_canvas :
|
||||||
scale = min(height / frame_height, width / frame_width)
|
scale1 = min(height / frame_height, width / frame_width)
|
||||||
|
scale2 = min(height / frame_width, width / frame_height)
|
||||||
|
scale = max(scale1, scale2)
|
||||||
else:
|
else:
|
||||||
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
||||||
|
|
||||||
@ -2356,6 +2368,7 @@ def generate_video(
|
|||||||
seed,
|
seed,
|
||||||
num_inference_steps,
|
num_inference_steps,
|
||||||
guidance_scale,
|
guidance_scale,
|
||||||
|
audio_guidance_scale,
|
||||||
flow_shift,
|
flow_shift,
|
||||||
embedded_guidance_scale,
|
embedded_guidance_scale,
|
||||||
repeat_generation,
|
repeat_generation,
|
||||||
@ -2375,8 +2388,10 @@ def generate_video(
|
|||||||
video_guide,
|
video_guide,
|
||||||
keep_frames_video_guide,
|
keep_frames_video_guide,
|
||||||
video_mask,
|
video_mask,
|
||||||
|
audio_guide,
|
||||||
sliding_window_size,
|
sliding_window_size,
|
||||||
sliding_window_overlap,
|
sliding_window_overlap,
|
||||||
|
sliding_window_overlap_noise,
|
||||||
sliding_window_discard_last_frames,
|
sliding_window_discard_last_frames,
|
||||||
remove_background_image_ref,
|
remove_background_image_ref,
|
||||||
temporal_upsampling,
|
temporal_upsampling,
|
||||||
@ -2508,6 +2523,15 @@ def generate_video(
|
|||||||
# VAE Tiling
|
# VAE Tiling
|
||||||
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
|
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
|
||||||
|
|
||||||
|
diffusion_forcing = "diffusion_forcing" in model_filename
|
||||||
|
vace = "Vace" in model_filename
|
||||||
|
if diffusion_forcing:
|
||||||
|
fps = 24
|
||||||
|
elif audio_guide != None:
|
||||||
|
fps = 23
|
||||||
|
else:
|
||||||
|
fps = 16
|
||||||
|
|
||||||
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
||||||
# TeaCache
|
# TeaCache
|
||||||
trans.enable_teacache = tea_cache_setting > 0
|
trans.enable_teacache = tea_cache_setting > 0
|
||||||
@ -2517,12 +2541,10 @@ def generate_video(
|
|||||||
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
|
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
|
||||||
|
|
||||||
if image2video:
|
if image2video:
|
||||||
if '480p' in model_filename:
|
if '720p' in model_filename:
|
||||||
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
|
||||||
elif '720p' in model_filename:
|
|
||||||
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
||||||
else:
|
else:
|
||||||
raise gr.Error("Teacache not supported for this model")
|
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
||||||
else:
|
else:
|
||||||
if '1.3B' in model_filename:
|
if '1.3B' in model_filename:
|
||||||
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
|
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
|
||||||
@ -2535,6 +2557,18 @@ def generate_video(
|
|||||||
if "recam" in model_filename:
|
if "recam" in model_filename:
|
||||||
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
|
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
|
||||||
target_camera = model_mode
|
target_camera = model_mode
|
||||||
|
|
||||||
|
audio_proj_split = None
|
||||||
|
audio_scale = None
|
||||||
|
audio_context_lens = None
|
||||||
|
if audio_guide != None:
|
||||||
|
from fantasytalking.infer import parse_audio
|
||||||
|
import librosa
|
||||||
|
duration = librosa.get_duration(path=audio_guide)
|
||||||
|
video_length = min(int(fps * duration // 4) * 4 + 5, video_length)
|
||||||
|
audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device )
|
||||||
|
audio_scale = 1.0
|
||||||
|
|
||||||
import random
|
import random
|
||||||
if seed == None or seed <0:
|
if seed == None or seed <0:
|
||||||
seed = random.randint(0, 999999999)
|
seed = random.randint(0, 999999999)
|
||||||
@ -2551,11 +2585,11 @@ def generate_video(
|
|||||||
extra_generation = 0
|
extra_generation = 0
|
||||||
initial_total_windows = 0
|
initial_total_windows = 0
|
||||||
max_frames_to_generate = video_length
|
max_frames_to_generate = video_length
|
||||||
diffusion_forcing = "diffusion_forcing" in model_filename
|
|
||||||
vace = "Vace" in model_filename
|
|
||||||
phantom = "phantom" in model_filename
|
phantom = "phantom" in model_filename
|
||||||
if diffusion_forcing or vace:
|
if diffusion_forcing or vace:
|
||||||
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
|
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
|
||||||
|
else:
|
||||||
|
reuse_frames = 0
|
||||||
if diffusion_forcing and source_video != None:
|
if diffusion_forcing and source_video != None:
|
||||||
video_length += sliding_window_overlap
|
video_length += sliding_window_overlap
|
||||||
sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
|
sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
|
||||||
@ -2571,10 +2605,8 @@ def generate_video(
|
|||||||
initial_total_windows = 1
|
initial_total_windows = 1
|
||||||
|
|
||||||
first_window_video_length = video_length
|
first_window_video_length = video_length
|
||||||
fps = 24 if diffusion_forcing else 16
|
|
||||||
|
|
||||||
gen["sliding_window"] = sliding_window
|
gen["sliding_window"] = sliding_window
|
||||||
|
|
||||||
while not abort:
|
while not abort:
|
||||||
extra_generation += gen.get("extra_orders",0)
|
extra_generation += gen.get("extra_orders",0)
|
||||||
gen["extra_orders"] = 0
|
gen["extra_orders"] = 0
|
||||||
@ -2594,6 +2626,7 @@ def generate_video(
|
|||||||
guide_start_frame = 0
|
guide_start_frame = 0
|
||||||
video_length = first_window_video_length
|
video_length = first_window_video_length
|
||||||
gen["extra_windows"] = 0
|
gen["extra_windows"] = 0
|
||||||
|
start_time = time.time()
|
||||||
while not abort:
|
while not abort:
|
||||||
if sliding_window:
|
if sliding_window:
|
||||||
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
|
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
|
||||||
@ -2642,7 +2675,7 @@ def generate_video(
|
|||||||
|
|
||||||
if preprocess_type != None :
|
if preprocess_type != None :
|
||||||
send_cmd("progress", progress_args)
|
send_cmd("progress", progress_args)
|
||||||
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps)
|
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
|
||||||
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
|
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
|
||||||
if len(error) > 0:
|
if len(error) > 0:
|
||||||
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
||||||
@ -2678,8 +2711,7 @@ def generate_video(
|
|||||||
trans.teacache_counter = 0
|
trans.teacache_counter = 0
|
||||||
trans.num_steps = num_inference_steps
|
trans.num_steps = num_inference_steps
|
||||||
trans.teacache_skipped_steps = 0
|
trans.teacache_skipped_steps = 0
|
||||||
trans.previous_residual_uncond = None
|
trans.previous_residual = None
|
||||||
trans.previous_residual_cond = None
|
|
||||||
|
|
||||||
if image2video:
|
if image2video:
|
||||||
samples = wan_model.generate(
|
samples = wan_model.generate(
|
||||||
@ -2687,7 +2719,9 @@ def generate_video(
|
|||||||
image_start,
|
image_start,
|
||||||
image_end if image_end != None else None,
|
image_end if image_end != None else None,
|
||||||
frame_num=(video_length // 4)* 4 + 1,
|
frame_num=(video_length // 4)* 4 + 1,
|
||||||
max_area=MAX_AREA_CONFIGS[resolution_reformated],
|
# max_area=MAX_AREA_CONFIGS[resolution_reformated],
|
||||||
|
height = height,
|
||||||
|
width = width,
|
||||||
shift=flow_shift,
|
shift=flow_shift,
|
||||||
sampling_steps=num_inference_steps,
|
sampling_steps=num_inference_steps,
|
||||||
guide_scale=guidance_scale,
|
guide_scale=guidance_scale,
|
||||||
@ -2702,7 +2736,11 @@ def generate_video(
|
|||||||
slg_end = slg_end_perc/100,
|
slg_end = slg_end_perc/100,
|
||||||
cfg_star_switch = cfg_star_switch,
|
cfg_star_switch = cfg_star_switch,
|
||||||
cfg_zero_step = cfg_zero_step,
|
cfg_zero_step = cfg_zero_step,
|
||||||
add_frames_for_end_image = "image2video" in model_filename
|
add_frames_for_end_image = "image2video" in model_filename,
|
||||||
|
audio_cfg_scale= audio_guidance_scale,
|
||||||
|
audio_proj= audio_proj_split,
|
||||||
|
audio_scale= audio_scale,
|
||||||
|
audio_context_lens= audio_context_lens
|
||||||
)
|
)
|
||||||
elif diffusion_forcing:
|
elif diffusion_forcing:
|
||||||
samples = wan_model.generate(
|
samples = wan_model.generate(
|
||||||
@ -2720,14 +2758,17 @@ def generate_video(
|
|||||||
callback= callback,
|
callback= callback,
|
||||||
VAE_tile_size = VAE_tile_size,
|
VAE_tile_size = VAE_tile_size,
|
||||||
joint_pass = joint_pass,
|
joint_pass = joint_pass,
|
||||||
addnoise_condition = 20,
|
slg_layers = slg_layers,
|
||||||
|
slg_start = slg_start_perc/100,
|
||||||
|
slg_end = slg_end_perc/100,
|
||||||
|
addnoise_condition = sliding_window_overlap_noise,
|
||||||
ar_step = model_mode, #5
|
ar_step = model_mode, #5
|
||||||
causal_block_size = 5,
|
causal_block_size = 5,
|
||||||
causal_attention = True,
|
causal_attention = True,
|
||||||
fps = fps,
|
fps = fps,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
samples = wan_model.generate(
|
samples = wan_model.generate(
|
||||||
prompt,
|
prompt,
|
||||||
input_frames = src_video,
|
input_frames = src_video,
|
||||||
input_ref_images= src_ref_images,
|
input_ref_images= src_ref_images,
|
||||||
@ -2751,6 +2792,9 @@ def generate_video(
|
|||||||
slg_end = slg_end_perc/100,
|
slg_end = slg_end_perc/100,
|
||||||
cfg_star_switch = cfg_star_switch,
|
cfg_star_switch = cfg_star_switch,
|
||||||
cfg_zero_step = cfg_zero_step,
|
cfg_zero_step = cfg_zero_step,
|
||||||
|
overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1),
|
||||||
|
overlap_noise = sliding_window_overlap_noise,
|
||||||
|
vace = vace
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if temp_filename!= None and os.path.isfile(temp_filename):
|
if temp_filename!= None and os.path.isfile(temp_filename):
|
||||||
@ -2782,11 +2826,11 @@ def generate_video(
|
|||||||
print('\n'.join(tb))
|
print('\n'.join(tb))
|
||||||
send_cmd("error", new_error)
|
send_cmd("error", new_error)
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
trans.previous_residual = None
|
||||||
|
|
||||||
if trans.enable_teacache:
|
if trans.enable_teacache:
|
||||||
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
|
||||||
trans.previous_residual_uncond = None
|
|
||||||
trans.previous_residual_cond = None
|
|
||||||
|
|
||||||
if samples != None:
|
if samples != None:
|
||||||
samples = samples.to("cpu")
|
samples = samples.to("cpu")
|
||||||
@ -2810,14 +2854,27 @@ def generate_video(
|
|||||||
if discard_last_frames > 0:
|
if discard_last_frames > 0:
|
||||||
sample = sample[: , :-discard_last_frames]
|
sample = sample[: , :-discard_last_frames]
|
||||||
guide_start_frame -= discard_last_frames
|
guide_start_frame -= discard_last_frames
|
||||||
pre_video_guide = sample[:, -reuse_frames:]
|
if reuse_frames == 0:
|
||||||
|
pre_video_guide = sample[:,9999 :]
|
||||||
|
else:
|
||||||
|
# noise_factor = 200/ 1000
|
||||||
|
# pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor
|
||||||
|
pre_video_guide = sample[:, -reuse_frames:]
|
||||||
|
|
||||||
|
|
||||||
if prefix_video != None:
|
if prefix_video != None:
|
||||||
sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
|
if reuse_frames == 0:
|
||||||
|
sample = torch.cat([ prefix_video[:, :], sample], dim = 1)
|
||||||
|
else:
|
||||||
|
sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
|
||||||
prefix_video = None
|
prefix_video = None
|
||||||
if sliding_window and window_no > 1:
|
if sliding_window and window_no > 1:
|
||||||
sample = sample[: , reuse_frames:]
|
if reuse_frames == 0:
|
||||||
guide_start_frame -= reuse_frames
|
sample = sample[: , :]
|
||||||
|
else:
|
||||||
|
sample = sample[: , reuse_frames:]
|
||||||
|
|
||||||
|
guide_start_frame -= reuse_frames
|
||||||
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
|
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
|
||||||
@ -2875,18 +2932,23 @@ def generate_video(
|
|||||||
sample = torch.cat([frames_already_processed, sample], dim=1)
|
sample = torch.cat([frames_already_processed, sample], dim=1)
|
||||||
frames_already_processed = sample
|
frames_already_processed = sample
|
||||||
|
|
||||||
cache_video(
|
if audio_guide == None:
|
||||||
tensor=sample[None],
|
cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
|
||||||
save_file=video_path,
|
else:
|
||||||
fps=fps,
|
save_path_tmp = video_path[:-4] + "_tmp.mp4"
|
||||||
nrow=1,
|
cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
|
||||||
normalize=True,
|
final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ]
|
||||||
value_range=(-1, 1))
|
import subprocess
|
||||||
|
subprocess.run(final_command, check=True)
|
||||||
|
os.remove(save_path_tmp)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
inputs = get_function_arguments(generate_video, locals())
|
inputs = get_function_arguments(generate_video, locals())
|
||||||
inputs.pop("send_cmd")
|
inputs.pop("send_cmd")
|
||||||
|
inputs.pop("task_id")
|
||||||
configs = prepare_inputs_dict("metadata", inputs)
|
configs = prepare_inputs_dict("metadata", inputs)
|
||||||
|
configs["generation_time"] = round(end_time-start_time)
|
||||||
metadata_choice = server_config.get("metadata_type","metadata")
|
metadata_choice = server_config.get("metadata_type","metadata")
|
||||||
if metadata_choice == "json":
|
if metadata_choice == "json":
|
||||||
with open(video_path.replace('.mp4', '.json'), 'w') as f:
|
with open(video_path.replace('.mp4', '.json'), 'w') as f:
|
||||||
@ -3113,7 +3175,7 @@ def get_latest_status(state):
|
|||||||
prompt_no = gen["prompt_no"]
|
prompt_no = gen["prompt_no"]
|
||||||
prompts_max = gen.get("prompts_max",0)
|
prompts_max = gen.get("prompts_max",0)
|
||||||
total_generation = gen.get("total_generation", 1)
|
total_generation = gen.get("total_generation", 1)
|
||||||
repeat_no = gen["repeat_no"]
|
repeat_no = gen.get("repeat_no",0)
|
||||||
total_generation += gen.get("extra_orders", 0)
|
total_generation += gen.get("extra_orders", 0)
|
||||||
total_windows = gen.get("total_windows", 0)
|
total_windows = gen.get("total_windows", 0)
|
||||||
total_windows += gen.get("extra_windows", 0)
|
total_windows += gen.get("extra_windows", 0)
|
||||||
@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ):
|
|||||||
|
|
||||||
if target == "state":
|
if target == "state":
|
||||||
return inputs
|
return inputs
|
||||||
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
|
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"]
|
||||||
for k in unsaved_params:
|
for k in unsaved_params:
|
||||||
inputs.pop(k)
|
inputs.pop(k)
|
||||||
|
|
||||||
@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ):
|
|||||||
inputs.pop(k)
|
inputs.pop(k)
|
||||||
|
|
||||||
if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
|
if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
|
||||||
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"]
|
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
|
||||||
for k in unsaved_params:
|
for k in unsaved_params:
|
||||||
inputs.pop(k)
|
inputs.pop(k)
|
||||||
|
|
||||||
|
if not "fantasy" in model_filename:
|
||||||
|
inputs.pop("audio_guidance_scale")
|
||||||
|
|
||||||
|
|
||||||
if target == "metadata":
|
if target == "metadata":
|
||||||
inputs = {k: v for k,v in inputs.items() if v != None }
|
inputs = {k: v for k,v in inputs.items() if v != None }
|
||||||
|
|
||||||
@ -3511,6 +3577,7 @@ def save_inputs(
|
|||||||
seed,
|
seed,
|
||||||
num_inference_steps,
|
num_inference_steps,
|
||||||
guidance_scale,
|
guidance_scale,
|
||||||
|
audio_guidance_scale,
|
||||||
flow_shift,
|
flow_shift,
|
||||||
embedded_guidance_scale,
|
embedded_guidance_scale,
|
||||||
repeat_generation,
|
repeat_generation,
|
||||||
@ -3530,8 +3597,10 @@ def save_inputs(
|
|||||||
video_guide,
|
video_guide,
|
||||||
keep_frames_video_guide,
|
keep_frames_video_guide,
|
||||||
video_mask,
|
video_mask,
|
||||||
|
audio_guide,
|
||||||
sliding_window_size,
|
sliding_window_size,
|
||||||
sliding_window_overlap,
|
sliding_window_overlap,
|
||||||
|
sliding_window_overlap_noise,
|
||||||
sliding_window_discard_last_frames,
|
sliding_window_discard_last_frames,
|
||||||
remove_background_image_ref,
|
remove_background_image_ref,
|
||||||
temporal_upsampling,
|
temporal_upsampling,
|
||||||
@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
recammaster = "recam" in model_filename
|
recammaster = "recam" in model_filename
|
||||||
vace = "Vace" in model_filename
|
vace = "Vace" in model_filename
|
||||||
phantom = "phantom" in model_filename
|
phantom = "phantom" in model_filename
|
||||||
|
fantasy = "fantasy" in model_filename
|
||||||
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
|
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
|
||||||
if diffusion_forcing:
|
if diffusion_forcing:
|
||||||
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
|
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
|
||||||
@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
|
|
||||||
|
|
||||||
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
|
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
|
||||||
|
audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy )
|
||||||
|
|
||||||
advanced_prompt = advanced_ui
|
advanced_prompt = advanced_ui
|
||||||
prompt_vars=[]
|
prompt_vars=[]
|
||||||
@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
||||||
wizard_variables_var = gr.Text(wizard_variables, visible = False)
|
wizard_variables_var = gr.Text(wizard_variables, visible = False)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
if test_class_i2v(model_filename):
|
if test_class_i2v(model_filename) and False:
|
||||||
resolution = gr.Dropdown(
|
resolution = gr.Dropdown(
|
||||||
choices=[
|
choices=[
|
||||||
# 720p
|
# 720p
|
||||||
("720p", "1280x720"),
|
("720p (same amount of pixels)", "1280x720"),
|
||||||
("480p", "832x480"),
|
("480p (same amount of pixels)", "832x480"),
|
||||||
],
|
],
|
||||||
value=ui_defaults.get("resolution","480p"),
|
value=ui_defaults.get("resolution","480p"),
|
||||||
label="Resolution (video will have the same height / width ratio than the original image)"
|
label="Resolution (video will have the same height / width ratio than the original image)"
|
||||||
@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
("1280x720 (16:9, 720p)", "1280x720"),
|
("1280x720 (16:9, 720p)", "1280x720"),
|
||||||
("720x1280 (9:16, 720p)", "720x1280"),
|
("720x1280 (9:16, 720p)", "720x1280"),
|
||||||
("1024x1024 (4:3, 720p)", "1024x024"),
|
("1024x1024 (4:3, 720p)", "1024x024"),
|
||||||
# ("832x1104 (3:4, 720p)", "832x1104"),
|
("832x1104 (3:4, 720p)", "832x1104"),
|
||||||
# ("960x960 (1:1, 720p)", "960x960"),
|
("1104x832 (3:4, 720p)", "1104x832"),
|
||||||
|
("960x960 (1:1, 720p)", "960x960"),
|
||||||
# 480p
|
# 480p
|
||||||
("960x544 (16:9, 540p)", "960x544"),
|
("960x544 (16:9, 540p)", "960x544"),
|
||||||
("544x960 (16:9, 540p)", "544x960"),
|
("544x960 (16:9, 540p)", "544x960"),
|
||||||
("832x480 (16:9, 480p)", "832x480"),
|
("832x480 (16:9, 480p)", "832x480"),
|
||||||
("480x832 (9:16, 480p)", "480x832"),
|
("480x832 (9:16, 480p)", "480x832"),
|
||||||
# ("832x624 (4:3, 540p)", "832x624"),
|
("832x624 (4:3, 480p)", "832x624"),
|
||||||
# ("624x832 (3:4, 540p)", "624x832"),
|
("624x832 (3:4, 480p)", "624x832"),
|
||||||
# ("720x720 (1:1, 540p)", "720x720"),
|
("720x720 (1:1, 480p)", "720x720"),
|
||||||
|
("512x512 (1:1, 480p)", "512x512"),
|
||||||
],
|
],
|
||||||
value=ui_defaults.get("resolution","832x480"),
|
value=ui_defaults.get("resolution","832x480"),
|
||||||
label="Resolution"
|
label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
if recammaster:
|
if recammaster:
|
||||||
@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
|
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
|
||||||
elif vace:
|
elif vace:
|
||||||
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
|
video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
|
||||||
|
elif fantasy:
|
||||||
|
video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
|
||||||
else:
|
else:
|
||||||
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
|
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
choices=[
|
choices=[
|
||||||
("Generate every combination of images and texts", 0),
|
("Generate every combination of images and texts", 0),
|
||||||
("Match images and text prompts", 1),
|
("Match images and text prompts", 1),
|
||||||
], visible= True, label= "Multiple Images as Texts Prompts"
|
], visible= test_class_i2v(model_filename), label= "Multiple Images as Texts Prompts"
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
|
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
|
||||||
|
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
|
||||||
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
|
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
|
||||||
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
|
flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
|
|
||||||
with gr.Tab("Quality"):
|
with gr.Tab("Quality"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
|
gr.Markdown("<B>Skip Layer Guidance (improves video quality)</B>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
slg_switch = gr.Dropdown(
|
slg_switch = gr.Dropdown(
|
||||||
choices=[
|
choices=[
|
||||||
@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
if diffusion_forcing:
|
if diffusion_forcing:
|
||||||
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
|
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
|
||||||
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||||
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
|
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
|
||||||
|
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
|
||||||
else:
|
else:
|
||||||
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
|
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
|
||||||
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||||
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
|
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
|
||||||
|
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
|
||||||
|
|
||||||
|
|
||||||
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
|
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
|
||||||
@ -4167,8 +4244,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|||||||
label="RIFLEx positional embedding to generate long video"
|
label="RIFLEx positional embedding to generate long video"
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
|
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
|
||||||
|
|
||||||
if not update_form:
|
if not update_form:
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -5035,7 +5112,7 @@ def create_demo():
|
|||||||
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
|
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
|
||||||
|
|
||||||
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
|
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
|
||||||
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.5 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||||
global model_list
|
global model_list
|
||||||
|
|
||||||
tab_state = gr.State({ "tab_no":0 })
|
tab_state = gr.State({ "tab_no":0 })
|
||||||
@ -5076,7 +5153,7 @@ def create_demo():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
atexit.register(autosave_queue)
|
atexit.register(autosave_queue)
|
||||||
# download_ffmpeg()
|
download_ffmpeg()
|
||||||
# threading.Thread(target=runner, daemon=True).start()
|
# threading.Thread(target=runner, daemon=True).start()
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
server_port = int(args.server_port)
|
server_port = int(args.server_port)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user