pain reliever

This commit is contained in:
DeepBeepMeep 2025-09-02 20:39:31 +02:00
parent 6490af145a
commit 898b542cc6
14 changed files with 800 additions and 183 deletions

View File

@@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates :
+### September 2 2025: WanGP v8.3 - At last the pain stops
+- This single new feature should give you the strength to face all the potential bugs of this new release: **Image Management (multiple additions or deletions, reordering) for Start Images / End Images / Image References**.
+- Unofficial **Video to Video (non sparse this time) for InfiniteTalk**. Use the Noise Strength slider to decide how much of the original window's motion you want to keep (a short sketch of what this setting does follows this section). I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk*, especially for the multi-speaker version and when generating at 1080p.
+- **Experimental Sage 3 Attention support**: you will have to earn this one. First you need a Blackwell GPU (RTX 50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash.
+
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended).
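A short, hedged sketch of what the strength setting mentioned above controls, under the simplifying assumption that it only shifts the step at which denoising of the source window begins (the formula mirrors the gr.Info message changed in wgp.py further down; the helper itself is hypothetical):

```python
# Hypothetical helper, not part of WanGP.
def denoise_start_step(num_inference_steps: int, denoising_strength: float) -> int:
    # e.g. 30 steps at strength 0.5 -> denoising starts at step 15;
    # the lower the strength, the later denoising starts and the closer
    # the result stays to the source window.
    return int(num_inference_steps * (1.0 - denoising_strength))
```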

View File

@@ -53,7 +53,7 @@ class family_handler():
        if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
-       if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
+       if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
            extra_model_def["one_image_ref_needed"] = True
        return extra_model_def

View File

@@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import (
)

-@lru_cache
+# @lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
    return block_mask
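For context, a minimal usage sketch (assumed, not taken from the repo) of the helper above: with the `@lru_cache` decorator commented out, the caller decides how long to keep a BlockMask around and can drop it to free GPU memory.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # mask_mod: returns True where attention is allowed
    return q_idx >= kv_idx

def run_flex_attention(q, k, v):
    # q, k, v: (B, H, S, D) CUDA tensors; the caller may cache the mask per shape
    # if it is reused across denoising steps.
    S = q.shape[-2]
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device=q.device)
    return flex_attention(q, k, v, block_mask=block_mask)
```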

View File

@@ -204,7 +204,7 @@ class QwenEmbedRope(nn.Module):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

-           if not torch.compiler.is_compiling():
+           if not torch.compiler.is_compiling() and False:
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
@@ -224,7 +224,6 @@ class QwenEmbedRope(nn.Module):
        return vid_freqs, txt_freqs

-   @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
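Both caches above are being disabled here (`and False`, removal of `functools.lru_cache`). Caching a bound method with `lru_cache` keeps `self` and the returned CUDA tensors alive for the lifetime of the cache, which pins VRAM. A minimal sketch of the alternative this points towards, an explicitly managed cache the caller can clear (class and method names are invented for illustration):

```python
import torch

class RopeFreqCache:
    def __init__(self, max_entries: int = 8):
        self.max_entries = max_entries
        self._store: dict[tuple, torch.Tensor] = {}

    def get_or_compute(self, key: tuple, compute):
        freqs = self._store.get(key)
        if freqs is None:
            if len(self._store) >= self.max_entries:
                self._store.pop(next(iter(self._store)))  # drop the oldest entry
            freqs = self._store[key] = compute()
        return freqs

    def clear(self):
        # call between generations so the cached tensors can be garbage collected
        self._store.clear()
```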

View File

@@ -19,7 +19,8 @@ from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
-from .modules.model import WanModel, clear_caches
+from .modules.model import WanModel
+from mmgp.offload import get_cache, clear_caches
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .modules.vae2_2 import Wan2_2_VAE
@@ -496,6 +497,8 @@ class WanAny2V:
            text_len = self.model.text_len
            context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
            context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0)
+       if input_video is not None: height, width = input_video.shape[-2:]
        # NAG_prompt = "static, low resolution, blurry"
        # context_NAG = self.text_encoder([NAG_prompt], self.device)[0]
        # context_NAG = context_NAG.to(self.dtype)
@@ -530,9 +533,10 @@ class WanAny2V:
        if image_start is None:
            if infinitetalk:
                if input_frames is not None:
-                   image_ref = input_frames[:, -1]
-                   if input_video is None: input_video = input_frames[:, -1:]
+                   image_ref = input_frames[:, 0]
+                   if input_video is None: input_video = input_frames[:, 0:1]
                    new_shot = "Q" in video_prompt_type
+                   denoising_strength = 0.5
                else:
                    if pre_video_frame is None:
                        new_shot = True
@@ -888,6 +892,7 @@ class WanAny2V:
                    latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents
                else:
                    latent_noise_factor = t / 1000
+                   latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
                if vace:
                    overlap_noise_factor = overlap_noise / 1000
                    for zz in z:
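A simplified sketch of the idea behind the added line: the latents that overlap the previous window are re-noised to the current timestep rather than copied verbatim, so some of the original motion is kept while the denoiser can still deviate in the new window (tensor shapes follow the code above; the helper itself is illustrative):

```python
import torch

def renoise_overlap(latents, overlap_latents, t, t_max=1000.0):
    """Blend clean overlap latents with Gaussian noise according to timestep t."""
    noise_factor = t / t_max              # ~1.0 at the start of sampling, ~0.0 at the end
    n = overlap_latents.shape[2]          # number of overlapping latent frames
    noise = torch.randn_like(overlap_latents)
    latents[:, :, :n] = overlap_latents * (1.0 - noise_factor) + noise * noise_factor
    return latents
```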

View File

@@ -12,6 +12,7 @@ from diffusers.models.modeling_utils import ModelMixin
import numpy as np
from typing import Union,Optional
from mmgp import offload
+from mmgp.offload import get_cache, clear_caches
from shared.attention import pay_attention
from torch.backends.cuda import sdp_kernel
from ..multitalk.multitalk_utils import get_attn_map_with_target
@@ -19,22 +20,6 @@ from ..multitalk.multitalk_utils import get_attn_map_with_target
__all__ = ['WanModel']

-def get_cache(cache_name):
-    all_cache = offload.shared_state.get("_cache", None)
-    if all_cache is None:
-        all_cache = {}
-        offload.shared_state["_cache"]= all_cache
-    cache = offload.shared_state.get(cache_name, None)
-    if cache is None:
-        cache = {}
-        offload.shared_state[cache_name] = cache
-    return cache
-
-def clear_caches():
-    all_cache = offload.shared_state.get("_cache", None)
-    if all_cache is not None:
-        all_cache.clear()
-
def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
@@ -579,19 +564,23 @@ class WanAttentionBlock(nn.Module):
            y = self.norm_x(x)
            y = y.to(attention_dtype)
            if ref_images_count == 0:
-               x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
+               ylist= [y]
+               del y
+               x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
            else:
                y_shape = y.shape
                y = y.reshape(y_shape[0], grid_sizes[0], -1)
                y = y[:, ref_images_count:]
                y = y.reshape(y_shape[0], -1, y_shape[-1])
                grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
-               y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
+               ylist= [y]
+               y = None
+               y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
                y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
                x = x.reshape(y_shape[0], grid_sizes[0], -1)
                x[:, ref_images_count:] += y
                x = x.reshape(y_shape[0], -1, y_shape[-1])
                del y

            y = self.norm2(x)
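The local `get_cache`/`clear_caches` helpers removed above now come from `mmgp.offload`. A hedged sketch of the pattern (simplified; mmgp's actual implementation may differ): named sub-caches kept in a process-wide shared dict, so cached tensors can be reused across modules and dropped with a single call between generations.

```python
from mmgp import offload

def get_named_cache(cache_name: str) -> dict:
    # one sub-dict per named cache, all hanging off the shared offload state
    all_cache = offload.shared_state.setdefault("_cache", {})
    return all_cache.setdefault(cache_name, {})

def clear_named_caches() -> None:
    all_cache = offload.shared_state.get("_cache", None)
    if all_cache is not None:
        all_cache.clear()

# Typical use (key names are illustrative):
#   cache = get_named_cache("multitalk_rope")
#   freqs = cache.get("q") or cache.setdefault("q", compute_freqs())
#   clear_named_caches()   # drop everything once generation finishes
```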

View File

@@ -221,13 +221,16 @@ class SingleStreamAttention(nn.Module):
        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

-   def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
+   def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
        N_t, N_h, N_w = shape
+       x = xlist[0]
+       xlist.clear()

        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
+       del x
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))
@@ -247,9 +250,6 @@ class SingleStreamAttention(nn.Module):
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
-       attn_bias = None
-       # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)
@@ -302,7 +302,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
-               x: torch.Tensor,
+               xlist: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
@@ -310,14 +310,17 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if x_ref_attn_map == None:
-           return super().forward(x, encoder_hidden_states, shape)
+           return super().forward(xlist, encoder_hidden_states, shape)

        N_t, _, _ = shape
+       x = xlist[0]
+       xlist.clear()
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
+       del x
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))
@@ -339,7 +342,9 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
-       q = self.rope_1d(q, normalized_pos)
+       qlist = [q]
+       del q
+       q = self.rope_1d(qlist, normalized_pos, "q")
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
@@ -347,7 +352,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)
+       del encoder_kv

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)
@@ -356,13 +361,14 @@ class SingleStreamMutiAttention(SingleStreamAttention):
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)

        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
-       encoder_k = self.rope_1d(encoder_k, encoder_pos)
+       enclist = [encoder_k]
+       del encoder_k
+       encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k")
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
-       # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        qkv_list = [q, encoder_k, encoder_v]
        q = encoder_k = encoder_v = None
        x = pay_attention(qkv_list)
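A toy sketch of the calling convention these new signatures follow (`xlist`, `qlist`, `qkv_list`): the tensor is handed over inside a one-element list that the callee empties, and the caller deletes its own name, so the only remaining reference lives inside the callee and the activation can be freed as early as possible, lowering peak VRAM. Names below are illustrative.

```python
import torch

def consume(tensor_list):
    x = tensor_list[0]
    tensor_list.clear()      # the caller's list no longer keeps the tensor alive
    out = x * 2              # stand-in for the real attention computation
    del x                    # release the input before allocating more
    return out

def caller(x: torch.Tensor) -> torch.Tensor:
    xlist = [x]
    del x                    # transfer ownership to the list
    return consume(xlist)
```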

View File

@@ -16,7 +16,7 @@ import torchvision
import binascii
import os.path as osp
from skimage import color
+from mmgp.offload import get_cache, clear_caches

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

ASPECT_RATIO_627 = {
@@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
# @torch.compile
-def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):
-    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
+def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None):
+    dtype = visual_q.dtype
+    ref_k = ref_k.to(dtype).to(visual_q.device)
+    scale = 1.0 / visual_q.shape[-1] ** 0.5
+    visual_q = visual_q * scale
+    visual_q = visual_q.transpose(1, 2)
+    ref_k = ref_k.transpose(1, 2)
+    visual_q_shape = visual_q.shape
+    visual_q = visual_q.view(-1, visual_q_shape[-1] )
+    number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2
+    chunk_size = int(visual_q_shape[-2] / number_chunks)
+    chunks =torch.split(visual_q, chunk_size)
+    maps_lists = [ [] for _ in ref_target_masks]
+    for q_chunk in chunks:
+        attn = q_chunk @ ref_k.transpose(-2, -1)
+        x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
+        del attn
+        ref_target_masks = ref_target_masks.to(dtype)
+        x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
+        for class_idx, ref_target_mask in enumerate(ref_target_masks):
+            ref_target_mask = ref_target_mask[None, None, None, ...]
+            x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
+            x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
+            maps_lists[class_idx].append(x_ref_attnmap)
+        del x_ref_attn_map_source
+
+    x_ref_attn_maps = []
+    for class_idx, maps_list in enumerate(maps_lists):
+        attn_map_fuse = torch.concat(maps_list, dim= -1)
+        attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1)
+        x_ref_attn_maps.append( attn_map_fuse )
+
+    return torch.concat(x_ref_attn_maps, dim=0)
+
+def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count):
+    dtype = visual_q.dtype
+    ref_k = ref_k.to(dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)
-   if attn_bias is not None: attn += attn_bias
    x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
+   del attn

    x_ref_attn_maps = []
-   ref_target_masks = ref_target_masks.to(visual_q.dtype)
-   x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
+   ref_target_masks = ref_target_masks.to(dtype)
+   x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H
-       if mode == 'mean':
-           x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
-       elif mode == 'max':
-           x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens
+       x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads)

        x_ref_attn_maps.append(x_ref_attnmap)

-   del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)

def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
    """Args:
        query (torch.tensor): B M H K
@@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
+   if x_seqlens <= 1508:
+       split_num = 10 # 540p
+   else:
+       split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p
    ref_k = ref_k[:, :x_seqlens]
    if ref_images_count > 0 :
        visual_q_shape = visual_q.shape
@@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
    split_chunk = heads // split_num

-   for i in range(split_num):
-       x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
-       x_ref_attn_maps += x_ref_attn_maps_perhead
+   if split_chunk == 1:
+       for i in range(split_num):
+           x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count)
+           x_ref_attn_maps += x_ref_attn_maps_perhead
+   else:
+       for i in range(split_num):
+           x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
+           x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps
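A hedged sketch of the memory idea behind the new per-head / chunked path: instead of materialising the full query-to-reference attention map in one go, query rows are processed in chunks and only the reduced per-class scores are kept. The repo derives the chunk size from an empirical budget (the 53090100 constant above); here it is just a parameter.

```python
import torch

def chunked_ref_scores(q, ref_k, ref_mask, chunk_size=4096):
    """q: (S, K), ref_k: (R, K), ref_mask: (R,) -> one score per query row, shape (S,)."""
    outs = []
    for q_chunk in torch.split(q, chunk_size):
        attn = (q_chunk @ ref_k.T).softmax(-1)                 # (chunk, R), freed every loop
        outs.append((attn * ref_mask).sum(-1) / ref_mask.sum())
        del attn
    return torch.cat(outs, dim=0)
```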
@@ -158,7 +196,6 @@ class RotaryPositionalEmbedding1D(nn.Module):
        self.base = 10000

-   @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):

        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
@@ -167,7 +204,7 @@ class RotaryPositionalEmbedding1D(nn.Module):
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

-   def forward(self, x, pos_indices):
+   def forward(self, qlist, pos_indices, cache_entry = None):
        """1D RoPE.

        Args:
@@ -176,16 +213,26 @@ class RotaryPositionalEmbedding1D(nn.Module):
        Returns:
            query with the same shape as input.
        """
-       freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
-       x_ = x.float()
-       freqs_cis = freqs_cis.float().to(x.device)
-       cos, sin = freqs_cis.cos(), freqs_cis.sin()
-       cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
-       x_ = (x_ * cos) + (rotate_half(x_) * sin)
-       return x_.type_as(x)
+       xq= qlist[0]
+       qlist.clear()
+       cache = get_cache("multitalk_rope")
+       freqs_cis= cache.get(cache_entry, None)
+       if freqs_cis is None:
+           freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices)
+       cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0)
+       # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
+       # real * cos - imag * sin
+       # imag * cos + real * sin
+       xq_dtype = xq.dtype
+       xq_out = xq.to(torch.float)
+       xq = None
+       xq_rot = rotate_half(xq_out)
+       xq_out *= cos
+       xq_rot *= sin
+       xq_out += xq_rot
+       del xq_rot
+       xq_out = xq_out.to(xq_dtype)
+       return xq_out
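For reference, a compact sketch of the rotation applied above: standard RoPE via the rotate_half trick, done in fp32 and mostly in place to limit temporaries. It assumes the half-split pairing convention and cos/sin broadcastable to the input (e.g. shape (1, 1, S, D)); per pair (x1, x2): out1 = x1·cos − x2·sin, out2 = x2·cos + x1·sin, i.e. x·cos + rotate_half(x)·sin.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_1d(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    out = x.float()
    rot = rotate_half(out)   # copy made before the in-place ops below
    out *= cos               # in-place: no extra full-size temporary
    rot *= sin
    out += rot
    return out.to(x.dtype)
```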

View File

@@ -1,5 +1,6 @@
import torch
import numpy as np
+import gradio as gr

def test_class_i2v(base_model_type):
    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
@@ -116,6 +117,11 @@ class family_handler():
            extra_model_def["no_background_removal"] = True
            # extra_model_def["at_least_one_image_ref_needed"] = True

+       if base_model_type in ["phantom_1.3B", "phantom_14B"]:
+           extra_model_def["one_image_ref_needed"] = True
+
        return extra_model_def

    @staticmethod
@@ -235,6 +241,14 @@ class family_handler():
            if "I" in video_prompt_type:
                video_prompt_type = video_prompt_type.replace("KI", "QKI")
            ui_defaults["video_prompt_type"] = video_prompt_type

+       if settings_version < 2.28:
+           if base_model_type in "infinitetalk":
+               video_prompt_type = ui_defaults.get("video_prompt_type", "")
+               if "U" in video_prompt_type:
+                   video_prompt_type = video_prompt_type.replace("U", "RU")
+               ui_defaults["video_prompt_type"] = video_prompt_type
+
    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults.update({
@@ -309,3 +323,11 @@ class family_handler():
        if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None:
            video_prompt_type = video_prompt_type.replace("I", "").replace("K","")
            inputs["video_prompt_type"] = video_prompt_type
+
+       if base_model_type in ["vace_standin_14B"]:
+           image_refs = inputs["image_refs"]
+           video_prompt_type = inputs["video_prompt_type"]
+           if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type:
+               gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.")
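A small sketch of the settings-migration pattern used in the new `settings_version < 2.28` block above (the wiring around it is assumed): each version bump gets a tiny, idempotent fix-up so presets saved with an older WanGP keep working.

```python
def migrate_settings(ui_defaults: dict, settings_version: float, base_model_type: str) -> dict:
    # hypothetical standalone version of the in-class migration step
    if settings_version < 2.28 and base_model_type == "infinitetalk":
        vpt = ui_defaults.get("video_prompt_type", "")
        if "U" in vpt and "R" not in vpt:
            # existing "U" prompts now also carry the new "R" (sparse video to video) flag
            ui_defaults["video_prompt_type"] = vpt.replace("U", "RU")
    return ui_defaults
```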

View File

@@ -20,6 +20,7 @@ soundfile
mutagen
pyloudnorm
librosa==0.11.0
+speechbrain==1.0.3

# UI & interaction
gradio==5.23.0
@@ -43,7 +44,7 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
-mmgp==3.5.10
+mmgp==3.5.11
peft==0.15.0
matplotlib

View File

@@ -3,6 +3,7 @@ import torch
from importlib.metadata import version
from mmgp import offload
import torch.nn.functional as F
+import warnings

major, minor = torch.cuda.get_device_capability(None)
bfloat16_supported = major >= 8
@@ -42,34 +43,51 @@ except ImportError:
    sageattn_varlen_wrapper = None

-import warnings

try:
-    from sageattention import sageattn
-    from .sage2_core import sageattn as alt_sageattn, is_sage2_supported
+    from .sage2_core import sageattn as sageattn2, is_sage2_supported
    sage2_supported = is_sage2_supported()
except ImportError:
-    sageattn = None
-    alt_sageattn = None
+    sageattn2 = None
    sage2_supported = False

-# @torch.compiler.disable()
-def sageattn_wrapper(
+@torch.compiler.disable()
+def sageattn2_wrapper(
        qkv_list,
        attention_length
    ):
    q,k, v = qkv_list
-    if True:
-        qkv_list = [q,k,v]
-        del q, k ,v
-        o = alt_sageattn(qkv_list, tensor_layout="NHD")
-    else:
-        o = sageattn(q, k, v, tensor_layout="NHD")
-        del q, k ,v
+    qkv_list = [q,k,v]
+    del q, k ,v
+    o = sageattn2(qkv_list, tensor_layout="NHD")

    qkv_list.clear()

    return o

+try:
+    from sageattn import sageattn_blackwell as sageattn3
+except ImportError:
+    sageattn3 = None
+
+@torch.compiler.disable()
+def sageattn3_wrapper(
+        qkv_list,
+        attention_length
+    ):
+    q,k, v = qkv_list
+    # qkv_list = [q,k,v]
+    # del q, k ,v
+    # o = sageattn3(qkv_list, tensor_layout="NHD")
+    q = q.transpose(1,2)
+    k = k.transpose(1,2)
+    v = v.transpose(1,2)
+    o = sageattn3(q, k, v)
+    o = o.transpose(1,2)
+
+    qkv_list.clear()
+
+    return o
+
# try:
# if True:
#     from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
@@ -94,7 +112,7 @@ def sageattn_wrapper(
#     return o
# except ImportError:
-#     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
+#     sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda

@torch.compiler.disable()
def sdpa_wrapper(
@@ -124,21 +142,28 @@ def get_attention_modes():
        ret.append("xformers")
    if sageattn_varlen_wrapper != None:
        ret.append("sage")
-    if sageattn != None and version("sageattention").startswith("2") :
+    if sageattn2 != None and version("sageattention").startswith("2") :
        ret.append("sage2")
+    if sageattn3 != None: # and version("sageattention").startswith("3") :
+        ret.append("sage3")

    return ret

def get_supported_attention_modes():
    ret = get_attention_modes()
+    major, minor = torch.cuda.get_device_capability()
+    if major < 10:
+        if "sage3" in ret:
+            ret.remove("sage3")
    if not sage2_supported:
        if "sage2" in ret:
            ret.remove("sage2")

-    major, minor = torch.cuda.get_device_capability()
    if major < 7:
        if "sage" in ret:
            ret.remove("sage")
    return ret

__all__ = [
@@ -201,7 +226,7 @@ def pay_attention(
        from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
        from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG

-    if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
+    if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
        assert attention_mask == None
        # Poor's man var k len attention
        assert q_lens == None
@@ -234,7 +259,7 @@ def pay_attention(
        q_chunks, k_chunks, v_chunks = None, None, None
        o = torch.cat(o, dim = 0)
        return o
-    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"):
+    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"):
        assert b == 1
        szq = q_lens[0].item() if q_lens != None else lq
        szk = k_lens[0].item() if k_lens != None else lk
@@ -284,13 +309,19 @@ def pay_attention(
            max_seqlen_q=lq,
            max_seqlen_kv=lk,
        ).unflatten(0, (b, lq))
+    elif attn=="sage3":
+        import math
+        if cross_attn or True:
+            qkv_list = [q,k,v]
+            del q,k,v
+            x = sageattn3_wrapper(qkv_list, lq)
    elif attn=="sage2":
        import math
        if cross_attn or True:
            qkv_list = [q,k,v]
            del q,k,v
-            x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
+            x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0)
    # else:
    #     layer = offload.shared_state["layer"]
    #     embed_sizes = offload.shared_state["embed_sizes"]
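A hedged sketch of how a caller could pick a backend from the lists built above (function names follow the diff, the preference order is illustrative): Sage 3 only on Blackwell-class GPUs (compute capability 10+), otherwise Sage 2 where supported, otherwise plain SDPA.

```python
import torch

def pick_attention_mode(supported: list[str]) -> str:
    major, _ = torch.cuda.get_device_capability()
    for candidate in ("sage3", "sage2", "sage", "sdpa"):
        if candidate == "sage3" and major < 10:
            continue                 # Sage 3 kernels target RTX 50xx (Blackwell)
        if candidate in supported:
            return candidate
    return "sdpa"

# Example:
#   mode = pick_attention_mode(get_supported_attention_modes())
```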

shared/gradio/gallery.py (new file, 496 lines)
View File

@@ -0,0 +1,496 @@
from __future__ import annotations
import os, io, tempfile, mimetypes
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
import gradio as gr
import PIL
from PIL import Image as PILImage
FilePath = str
ImageLike = Union["PIL.Image.Image", Any]
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
def get_state(state):
return state if isinstance(state, dict) else state.value
def get_list( objs):
if objs is None:
return []
return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
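# Overview: AdvancedMediaGallery backs the "Image Management" feature from the release
# notes above. It wraps a gr.Gallery plus Add / Remove / Left / Right / Clear controls
# and keeps everything in one dict state of the form {"items", "selected", "single",
# "mode"}, so Start/End/Reference images can be added in batches, removed and
# reordered; single-image mode makes Add replace the current item. The commented-out
# demo at the bottom of this file shows basic usage.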
class AdvancedMediaGallery:
def __init__(
self,
label: str = "Media",
*,
media_mode: Literal["image", "video"] = "image",
height = None,
columns: Union[int, Tuple[int, ...]] = 6,
show_label: bool = True,
initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None,
elem_id: Optional[str] = None,
elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",),
accept_filter: bool = True, # restrict Add-button dialog to allowed extensions
single_image_mode: bool = False, # start in single-image mode (Add replaces)
):
assert media_mode in ("image", "video")
self.label = label
self.media_mode = media_mode
self.height = height
self.columns = columns
self.show_label = show_label
self.elem_id = elem_id
self.elem_classes = list(elem_classes) if elem_classes else None
self.accept_filter = accept_filter
items = self._normalize_initial(initial or [], media_mode)
# Components (filled on mount)
self.container: Optional[gr.Column] = None
self.gallery: Optional[gr.Gallery] = None
self.upload_btn: Optional[gr.UploadButton] = None
self.btn_remove: Optional[gr.Button] = None
self.btn_left: Optional[gr.Button] = None
self.btn_right: Optional[gr.Button] = None
self.btn_clear: Optional[gr.Button] = None
# Single dict state
self.state: Optional[gr.State] = None
self._initial_state: Dict[str, Any] = {
"items": items,
"selected": (len(items) - 1) if items else None,
"single": bool(single_image_mode),
"mode": self.media_mode,
}
# ---------------- helpers ----------------
def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]:
out: List[Any] = []
if mode == "image":
for it in items:
p = self._ensure_image_item(it)
if p is not None:
out.append(p)
else:
for it in items:
if isinstance(it, tuple): it = it[0]
if isinstance(it, str) and self._is_video_path(it):
out.append(os.path.abspath(it))
return out
def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]:
# Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path
if isinstance(item, tuple): item = item[0]
if isinstance(item, str):
return os.path.abspath(item) if self._is_image_path(item) else None
if PILImage is None:
return None
try:
if isinstance(item, PILImage.Image):
img = item
else:
import numpy as np # type: ignore
if isinstance(item, np.ndarray):
img = PILImage.fromarray(item)
elif hasattr(item, "read"):
data = item.read()
img = PILImage.open(io.BytesIO(data)).convert("RGBA")
else:
return None
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(tmp.name)
return tmp.name
except Exception:
return None
@staticmethod
def _extract_path(obj: Any) -> Optional[str]:
# Try to get a filesystem path (for mode filtering); otherwise None.
if isinstance(obj, str):
return obj
try:
import pathlib
if isinstance(obj, pathlib.Path): # type: ignore
return str(obj)
except Exception:
pass
if isinstance(obj, dict):
return obj.get("path") or obj.get("name")
for attr in ("path", "name"):
if hasattr(obj, attr):
try:
val = getattr(obj, attr)
if isinstance(val, str):
return val
except Exception:
pass
return None
@staticmethod
def _is_image_path(p: str) -> bool:
ext = os.path.splitext(p)[1].lower()
if ext in IMAGE_EXTS:
return True
mt, _ = mimetypes.guess_type(p)
return bool(mt and mt.startswith("image/"))
@staticmethod
def _is_video_path(p: str) -> bool:
ext = os.path.splitext(p)[1].lower()
if ext in VIDEO_EXTS:
return True
mt, _ = mimetypes.guess_type(p)
return bool(mt and mt.startswith("video/"))
def _filter_items_by_mode(self, items: List[Any]) -> List[Any]:
# Enforce image-only or video-only collection regardless of how files were added.
out: List[Any] = []
if self.media_mode == "image":
for it in items:
p = self._extract_path(it)
if p is None:
# No path: likely an image object added programmatically => keep
out.append(it)
elif self._is_image_path(p):
out.append(os.path.abspath(p))
else:
for it in items:
p = self._extract_path(it)
if p is not None and self._is_video_path(p):
out.append(os.path.abspath(p))
return out
@staticmethod
def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]:
# Keep it simple: dedupe by path when available, else allow duplicates.
seen_paths = set()
def key(x: Any) -> Optional[str]:
if isinstance(x, str): return os.path.abspath(x)
try:
import pathlib
if isinstance(x, pathlib.Path): # type: ignore
return os.path.abspath(str(x))
except Exception:
pass
if isinstance(x, dict):
p = x.get("path") or x.get("name")
return os.path.abspath(p) if isinstance(p, str) else None
for attr in ("path", "name"):
if hasattr(x, attr):
try:
v = getattr(x, attr)
return os.path.abspath(v) if isinstance(v, str) else None
except Exception:
pass
return None
out: List[Any] = []
for lst in (cur, add):
for it in lst:
k = key(it)
if k is None or k not in seen_paths:
out.append(it)
if k is not None:
seen_paths.add(k)
return out
@staticmethod
def _paths_from_payload(payload: Any) -> List[Any]:
# Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly.
if payload is None:
return []
if isinstance(payload, (list, tuple, set)):
return list(payload)
return [payload]
# ---------------- event handlers ----------------
def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
# Mirror the selected index into state and the gallery (server-side selected_index)
idx = None
if evt is not None and hasattr(evt, "index"):
ix = evt.index
if isinstance(ix, int):
idx = ix
elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int):
if isinstance(self.columns, int) and len(ix) >= 2:
idx = ix[0] * max(1, int(self.columns)) + ix[1]
else:
idx = ix[0]
st = get_state(state)
n = len(get_list(gallery))
sel = idx if (idx is not None and 0 <= idx < n) else None
st["selected"] = sel
return gr.update(selected_index=sel), st
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
# Fires when users add/drag/drop/delete via the Gallery itself.
items_filtered = self._filter_items_by_mode(list(value or []))
st = get_state(state)
st["items"] = items_filtered
# Keep selection if still valid, else default to last
old_sel = st.get("selected", None)
if old_sel is None or not (0 <= old_sel < len(items_filtered)):
new_sel = (len(items_filtered) - 1) if items_filtered else None
else:
new_sel = old_sel
st["selected"] = new_sel
return gr.update(value=items_filtered, selected_index=new_sel), st
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
"""
Insert added items right AFTER the currently selected index.
Keeps the same ordering as chosen in the file picker, dedupes by path,
and re-selects the last inserted item.
"""
# New items (respect image/video mode)
new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
st = get_state(state)
cur: List[Any] = get_list(gallery)
sel = st.get("selected", None)
if sel is None:
sel = (len(cur) -1) if len(cur)>0 else 0
single = bool(st.get("single", False))
# Nothing to add: keep as-is
if not new_items:
return gr.update(value=cur, selected_index=st.get("selected")), st
# Single-image mode: replace
if single:
st["items"] = [new_items[-1]]
st["selected"] = 0
return gr.update(value=st["items"], selected_index=0), st
# ---------- helpers ----------
def key_of(it: Any) -> Optional[str]:
# Prefer class helper if present
if hasattr(self, "_extract_path"):
p = self._extract_path(it) # type: ignore
else:
p = it if isinstance(it, str) else None
if p is None and isinstance(it, dict):
p = it.get("path") or it.get("name")
if p is None and hasattr(it, "path"):
try: p = getattr(it, "path")
except Exception: p = None
if p is None and hasattr(it, "name"):
try: p = getattr(it, "name")
except Exception: p = None
return os.path.abspath(p) if isinstance(p, str) else None
# Dedupe the incoming batch by path, preserve order
seen_new = set()
incoming: List[Any] = []
for it in new_items:
k = key_of(it)
if k is None or k not in seen_new:
incoming.append(it)
if k is not None:
seen_new.add(k)
# Remove any existing occurrences of the incoming items from current list,
# BUT keep the currently selected item even if it's also in incoming.
cur_clean: List[Any] = []
# sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
# for idx, it in enumerate(cur):
# k = key_of(it)
# if it is sel_item:
# cur_clean.append(it)
# continue
# if k is not None and k in seen_new:
# continue # drop duplicate; we'll reinsert at the target spot
# cur_clean.append(it)
# # Compute insertion position: right AFTER the (possibly shifted) selected item
# if sel_item is not None:
# # find sel_item's new index in cur_clean
# try:
# pos_sel = cur_clean.index(sel_item)
# except ValueError:
# # Shouldn't happen, but fall back to end
# pos_sel = len(cur_clean) - 1
# insert_pos = pos_sel + 1
# else:
# insert_pos = len(cur_clean) # no selection -> append at end
insert_pos = min(sel, len(cur) -1)
cur_clean = cur
# Build final list and selection
merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:]
new_sel = insert_pos + len(incoming) # select the last inserted item
st["items"] = merged
st["selected"] = new_sel
return gr.update(value=merged, selected_index=new_sel), st
def _on_remove(self, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
if sel is None or not (0 <= sel < len(items)):
return gr.update(value=items, selected_index=st.get("selected")), st
items.pop(sel)
if not items:
st["items"] = []; st["selected"] = None
return gr.update(value=[], selected_index=None), st
new_sel = min(sel, len(items) - 1)
st["items"] = items; st["selected"] = new_sel
return gr.update(value=items, selected_index=new_sel), st
def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
if sel is None or not (0 <= sel < len(items)):
return gr.update(value=items, selected_index=sel), st
j = sel + delta
if j < 0 or j >= len(items):
return gr.update(value=items, selected_index=sel), st
items[sel], items[j] = items[j], items[sel]
st["items"] = items; st["selected"] = j
return gr.update(value=items, selected_index=j), st
def _on_clear(self, state: Dict[str, Any]) :
st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode}
return gr.update(value=[], selected_index=None), st
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
st = get_state(state); st["single"] = bool(to_single)
items: List[Any] = list(st["items"]); sel = st.get("selected", None)
if st["single"]:
keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None)
items = [keep] if keep is not None else []
sel = 0 if items else None
st["items"] = items; st["selected"] = sel
upload_update = gr.update(file_count=("single" if st["single"] else "multiple"))
left_update = gr.update(visible=not st["single"])
right_update = gr.update(visible=not st["single"])
clear_update = gr.update(visible=not st["single"])
gallery_update= gr.update(value=items, selected_index=sel)
return upload_update, left_update, right_update, clear_update, gallery_update, st
# ---------------- build & wire ----------------
def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
if parent is not None:
with parent:
col = self._build_ui()
else:
col = self._build_ui()
if not update_form:
self._wire_events()
return col
def _build_ui(self) -> gr.Column:
with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
self.container = col
self.state = gr.State(dict(self._initial_state))
self.gallery = gr.Gallery(
label=self.label,
value=self._initial_state["items"],
height=self.height,
columns=self.columns,
show_label=self.show_label,
preview= True,
selected_index=self._initial_state["selected"], # server-side selection
)
# One-line controls
exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
with gr.Row(equal_height=True, elem_classes=["amg-controls"]):
self.upload_btn = gr.UploadButton(
"Set" if self._initial_state["single"] else "Add",
file_types=exts,
file_count=("single" if self._initial_state["single"] else "multiple"),
variant="primary",
size="sm",
min_width=1,
)
self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
return col
def _wire_events(self):
# Selection: mirror into state and keep gallery.selected_index in sync
self.gallery.select(
self._on_select,
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
)
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
self.gallery.change(
self._on_gallery_change,
inputs=[self.gallery, self.state],
outputs=[self.gallery, self.state],
)
# Add via UploadButton
self.upload_btn.upload(
self._on_add,
inputs=[self.upload_btn, self.state, self.gallery],
outputs=[self.gallery, self.state],
)
# Remove selected
self.btn_remove.click(
self._on_remove,
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
)
# Reorder using selected index, keep same item selected
self.btn_left.click(
lambda st, gallery: self._on_move(-1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
)
self.btn_right.click(
lambda st, gallery: self._on_move(+1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
)
# Clear all
self.btn_clear.click(
self._on_clear,
inputs=[self.state],
outputs=[self.gallery, self.state],
)
# ---------------- public API ----------------
def set_one_image_mode(self, enabled: bool = True):
"""Toggle single-image mode at runtime."""
return (
self._on_toggle_single,
[gr.State(enabled), self.state],
[self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state],
)
def get_toggable_elements(self):
return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state]
# import gradio as gr
# with gr.Blocks() as demo:
# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8)
# amg.mount()
# g = amg.gallery
# # buttons to switch modes live (optional)
# def process(g):
# pass
# with gr.Row():
# gr.Button("toto").click(process, g)
# gr.Button("ONE image").click(*amg.set_one_image_mode(True))
# gr.Button("MULTI image").click(*amg.set_one_image_mode(False))
# demo.launch()

View File

@@ -9,7 +9,6 @@ import os
import sys
import time
import warnings
-from functools import lru_cache
from io import BytesIO

import requests
@@ -257,7 +256,6 @@ VIDEO_READER_BACKENDS = {
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)

-@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER

wgp.py (186 changed lines)
View File

@@ -49,6 +49,7 @@ logging.set_verbosity_error
from preprocessing.matanyone import app as matanyone_app
from tqdm import tqdm
import requests
+from shared.gradio.gallery import AdvancedMediaGallery

# import torch._dynamo as dynamo
# dynamo.config.recompile_limit = 2000 # default is 256
@@ -58,9 +59,9 @@ global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

-target_mmgp_version = "3.5.10"
-WanGP_version = "8.21"
-settings_version = 2.27
+target_mmgp_version = "3.5.11"
+WanGP_version = "8.3"
+settings_version = 2.28
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -466,7 +467,7 @@ def process_prompt_and_add_tasks(state, model_choice):
            image_mask = None
        if "G" in video_prompt_type:
-           gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ")
+           gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
        else:
            denoising_strength = 1.0
        if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
@ -4733,46 +4734,61 @@ def generate_video(
# special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
if video_guide is not None: if video_guide is not None:
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0: if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
if infinitetalk and video_guide is not None:
src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True) if ltxv:
new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) status_info = "Extracting " + processes_names[preprocess_type]
refresh_preview["video_guide"] = src_image send_cmd("progress", [0, get_latest_status(state, status_info)])
src_video = convert_image_to_tensor(src_image).unsqueeze(1) # start one frame ealier to facilitate latents merging later
if sample_fit_canvas != None: src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
image_size = src_video.shape[-2:] if src_video != None:
sample_fit_canvas = None src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
if ltxv and video_guide is not None: refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") src_video = src_video.permute(3, 0, 1, 2)
status_info = "Extracting " + processes_names[preprocess_type] src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
send_cmd("progress", [0, get_latest_status(state, status_info)]) if sample_fit_canvas != None:
# start one frame ealier to facilitate latents merging later image_size = src_video.shape[-2:]
src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) sample_fit_canvas = None
if src_video != None:
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] elif hunyuan_custom_edit:
if "P" in video_prompt_type:
progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
else:
progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
send_cmd("progress", progress_args)
src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
src_video = src_video.permute(3, 0, 1, 2) if src_mask != None:
src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
elif "R" in video_prompt_type: # sparse video to video
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
refresh_preview["video_guide"] = src_image
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
if sample_fit_canvas != None: if sample_fit_canvas != None:
image_size = src_video.shape[-2:] image_size = src_video.shape[-2:]
sample_fit_canvas = None sample_fit_canvas = None
if t2v and "G" in video_prompt_type: elif "G" in video_prompt_type: # video to video
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
if video_guide_processed == None: if video_guide_processed is None:
src_video = pre_video_guide src_video = pre_video_guide
else: else:
if sample_fit_canvas != None: if sample_fit_canvas != None:
image_size = video_guide_processed.shape[-3: -1] image_size = video_guide_processed.shape[-3: -1]
sample_fit_canvas = None sample_fit_canvas = None
src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
if pre_video_guide != None: if pre_video_guide != None:
src_video = torch.cat( [pre_video_guide, src_video], dim=1) src_video = torch.cat( [pre_video_guide, src_video], dim=1)
if vace : if vace :
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
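For reference, a minimal sketch of what the new "R" (sparse video to video) branch does per sliding window, using generic stand-ins for the repo's get_video_frame / calculate_new_dimensions / convert_image_to_tensor helpers (the function name and the imageio-based frame read below are assumptions, not the project's code):

import numpy as np
import torch
import imageio.v3 as iio
from PIL import Image

def sparse_guide_frame(video_path: str, frame_index: int, target_h: int, target_w: int) -> torch.Tensor:
    # One guide image is extracted per sliding window, resized to the working resolution,
    # and turned into a (C, 1, H, W) tensor in [-1, 1] so it can slot in where a clip would.
    frame = iio.imread(video_path, index=frame_index)                      # H x W x C, uint8
    img = Image.fromarray(frame).resize((target_w, target_h), Image.Resampling.LANCZOS)
    t = torch.from_numpy(np.asarray(img)).float().div_(127.5).sub_(1.0)    # H, W, C in [-1, 1]
    return t.permute(2, 0, 1).unsqueeze(1)                                 # C, F=1, H, W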
@@ -4834,17 +4850,8 @@ def generate_video(
if sample_fit_canvas != None:
image_size = src_video[0].shape[-2:]
sample_fit_canvas = None
-elif hunyuan_custom_edit:
-if "P" in video_prompt_type:
-progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
-else:
-progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
-send_cmd("progress", progress_args)
-src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
-refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
-if src_mask != None:
-refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
if len(refresh_preview) > 0:
new_inputs= locals()
new_inputs.update(refresh_preview)
@@ -6013,10 +6020,16 @@ def video_to_source_video(state, input_file_list, choice):
def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
file_list, file_settings_list = get_file_list(state, input_file_list)
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
-gr.Info(f"Selected Image was added to {target_name}")
-if target == None:
-target =[]
-target.append( file_list[choice])
+model_type = state["model_type"]
+model_def = get_model_def(model_type)
+if model_def.get("one_image_ref_needed", False):
+gr.Info(f"Selected Image was set to {target_name}")
+target =[file_list[choice]]
+else:
+gr.Info(f"Selected Image was added to {target_name}")
+if target == None:
+target =[]
+target.append( file_list[choice])
return target

def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
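The new branch above keys off the model's "one_image_ref_needed" flag; a compact sketch of the same decision as a pure helper (illustrative names, not the repo's API):

from typing import Any, List, Optional

def add_or_set_ref(target: Optional[List[Any]], new_item: Any, one_image_ref_needed: bool) -> List[Any]:
    # Single-reference models (e.g. the Hunyuan custom/avatar family) keep exactly one image,
    # so the selection replaces the list; every other model appends to it.
    if one_image_ref_needed:
        return [new_item]
    target = [] if target is None else target
    target.append(new_item)
    return target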
@@ -6229,6 +6242,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
if not "WanGP" in configs.get("type", ""): configs = None
except:
configs = None
+if configs is None: return None, False
current_model_filename = state["model_filename"]
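The added guard returns early when the file's metadata is missing or is not a WanGP payload, instead of failing on the next access; a small sketch of the pattern (hypothetical wrapper, not the repo's function signature):

import json

def parse_wangp_settings(raw_text: str):
    # Returns (configs, ok); ok is False whenever nothing usable could be parsed.
    try:
        configs = json.loads(raw_text)
        if "WanGP" not in configs.get("type", ""):
            configs = None
    except Exception:
        configs = None
    if configs is None:
        return None, False
    return configs, True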
@@ -6615,11 +6629,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)

def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
-video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI")
+video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
control_video_visible = "V" in video_prompt_type
ref_images_visible = "I" in video_prompt_type
-return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible )
+denoising_strength_visible = "G" in video_prompt_type
+return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible )

# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
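video_prompt_type is a string of single-letter flags, and the handler above swaps the InfiniteTalk-related letters (now "RGUVQKI") for whatever the dropdown selected, then derives widget visibility from the letters present. A sketch with local stand-ins for del_in_sequence / add_to_sequence (the real helpers live elsewhere in wgp.py and may differ in detail):

def del_in_sequence(seq: str, letters: str) -> str:
    # drop every flag that belongs to this dropdown before applying the new selection
    return "".join(c for c in seq if c not in letters)

def add_to_sequence(seq: str, letters: str) -> str:
    # append the newly selected flags, skipping ones already present
    return seq + "".join(c for c in letters if c not in seq)

# Example: switching from "Images to Video" ("KI") to "Video to Video - Smooth Transitions" ("GUV")
video_prompt_type = add_to_sequence(del_in_sequence("KI", "RGUVQKI"), "GUV")
denoising_strength_visible = "G" in video_prompt_type   # True -> the slider becomes visible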
@@ -6996,6 +7011,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
+def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
+with gr.Row(visible = visible) as gallery_row:
+gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
+gallery_amg.mount(update_form=update_form)
+return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements()
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
if not v2i_switch_supported and not image_outputs:
image_mode_value = 0
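get_image_gallery wraps the new AdvancedMediaGallery component in a Row and hands back everything whose visibility has to be toggled as one unit; a plain-Gradio stand-in of the same pattern (AdvancedMediaGallery's own API is not reproduced here):

import gradio as gr

def simple_image_gallery(label: str = "", value=None, visible: bool = False):
    # The Row is what callers show/hide; returning it alongside the Gallery lets event
    # handlers target the whole group (gallery + its controls) with one gr.update(visible=...).
    with gr.Row(visible=visible) as gallery_row:
        gallery = gr.Gallery(label=label, value=value, columns=4)
        clear_btn = gr.Button("Clear", size="sm")
    return gallery_row, gallery, [gallery_row, clear_btn]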
@@ -7009,15 +7030,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
pass
with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column:
if vace or infinitetalk:
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3)
-image_start = gr.Gallery(visible = False)
-image_end = gr.Gallery(visible = False)
+image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
model_mode = gr.Dropdown(visible = False)
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
@@ -7034,13 +7054,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type_choices += [("Continue Video", "V")]
image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3)
-# image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
-image_start = gr.Gallery(preview= True,
-label="Images as starting points for new videos", type ="pil", #file_types= "image",
-columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
-image_end = gr.Gallery(preview= True,
-label="Images as ending points for new videos", type ="pil", #file_types= "image",
-columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
if not diffusion_forcing:
model_mode = gr.Dropdown(
@@ -7061,8 +7076,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" )
elif recammaster:
image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V")
-image_start = gr.Gallery(value = None, visible = False)
-image_end = gr.Gallery(value = None, visible= False)
+image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),)
model_mode = gr.Dropdown(
choices=[
@@ -7095,21 +7110,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3)
any_start_image = True
any_end_image = True
-image_start = gr.Gallery(preview= True,
-label="Images as starting points for new videos", type ="pil", #file_types= "image",
-columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
-image_end = gr.Gallery(preview= True,
-label="Images as ending points for new videos", type ="pil", #file_types= "image",
-columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
if hunyuan_i2v:
video_source = gr.Video(value=None, visible=False)
else:
video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
else:
image_prompt_type = gr.Radio(choices=[("", "")], value="")
-image_start = gr.Gallery(value=None)
-image_end = gr.Gallery(value=None)
+image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(value=None, visible=False)
model_mode = gr.Dropdown(value=None, visible=False)
keep_frames_video_source = gr.Text(visible=False)
@@ -7184,12 +7194,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if infinitetalk:
video_prompt_type_video_guide_alt = gr.Dropdown(
choices=[
-("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"),
-("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"),
-("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
+("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
+("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
+("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
+("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
+("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
],
-value=filter_letters(video_prompt_type_value, "UVQKI"),
+value=filter_letters(video_prompt_type_value, "RGUVQKI"),
label="Video to Video", scale = 3, visible= True, show_label= False,
)
else:
@@ -7318,11 +7330,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image
-image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""),
-type ="pil", show_label= True,
-columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
-value= ui_defaults.get("image_refs", None),
-)
+image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
+image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
+image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
@@ -7935,7 +7946,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs,
-min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # presets_column,
+min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7953,11 +7964,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
-image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
+image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
-video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
+video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
-video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ])
+video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
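The rewired callbacks now return visibility updates for the wrapping rows (image_start_row, image_refs_row, ...) rather than for the galleries themselves, plus the denoising_strength slider for the new "G" mode. A self-contained sketch of that wiring pattern in plain Gradio (component names here are illustrative, not WanGP's):

import gradio as gr

with gr.Blocks() as demo:
    mode = gr.Dropdown(choices=["Images to Video", "Video to Video"], value="Images to Video")
    with gr.Row(visible=True) as refs_row:
        refs = gr.Gallery(label="Reference Images", columns=4)
    denoising_strength = gr.Slider(0.0, 1.0, value=0.5, label="Denoising Strength", visible=False)

    def on_mode(choice):
        v2v = choice == "Video to Video"
        # hiding the Row hides the gallery and anything else inside it in one update
        return gr.update(visible=not v2v), gr.update(visible=v2v)

    mode.input(fn=on_mode, inputs=[mode], outputs=[refs_row, denoising_strength])

# demo.launch()  # uncomment to try it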
@@ -8348,6 +8359,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
+("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
],
value= attention_mode,
label="Attention Type",
@@ -8663,7 +8675,7 @@ def generate_about_tab():
gr.Markdown("- <B>Blackforest Labs</B> for the innovative Flux image generators (https://github.com/black-forest-labs/flux)")
gr.Markdown("- <B>Alibaba Qwen Team</B> for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)")
gr.Markdown("- <B>Lightricks</B> for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)")
-gr.Markdown("- <B>Hugging Face</B> for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
+gr.Markdown("- <B>Hugging Face</B> for providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
gr.Markdown("<BR>Huge acknowledgments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
@@ -8672,7 +8684,7 @@ def generate_about_tab():
gr.Markdown("- <B>Pyannote</B>: speaker diarization (https://github.com/pyannote/pyannote-audio)")
gr.Markdown("<BR>Special thanks to the following people for their support:")
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")