mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-03 22:04:21 +00:00

commit 898b542cc6 ("pain reliever")
parent 6490af145a

README.md: 11 lines changed
@@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates:

### September 2 2025: WanGP v8.3 - At last the pain stops

- This single new feature should give you the strength to face all the potential bugs of this new release:
**Image Management (multiple additions or deletions, reordering) for Start Images / End Images / Image References.**

- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multi-speaker version & when generating at 1080p).

- **Experimental Sage 3 Attention support**: you will need to earn this one. First you need a Blackwell GPU (RTX 50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash.
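Before compiling anything you can sanity-check your setup with a short probe (a hedged sketch: the `sageattn_blackwell` import and the compute-capability >= 10 test mirror what this release's attention module does, but names on your install may differ):

```python
import torch

def sage3_available() -> bool:
    """Rough availability probe for Sage 3 attention (assumes the compiled
    `sageattn` package exposes `sageattn_blackwell`, as used in this release)."""
    try:
        from sageattn import sageattn_blackwell  # noqa: F401  (import only to test availability)
    except ImportError:
        return False  # Sage 3 not compiled / not installed
    major, _ = torch.cuda.get_device_capability()
    return major >= 10  # Blackwell (RTX 50xx) reports compute capability 10.x

print("Sage 3 usable:", sage3_available())
```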
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend

- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended).
@@ -53,7 +53,7 @@ class family_handler():
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
 | 
			
		||||
        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
 | 
			
		||||
            extra_model_def["one_image_ref_needed"] = True
 | 
			
		||||
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
 | 
			
		||||
@@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@lru_cache
 | 
			
		||||
# @lru_cache
 | 
			
		||||
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
 | 
			
		||||
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
 | 
			
		||||
    return block_mask
 | 
			
		||||
 | 
			
		||||
@@ -204,7 +204,7 @@ class QwenEmbedRope(nn.Module):
 | 
			
		||||
            frame, height, width = fhw
 | 
			
		||||
            rope_key = f"{idx}_{height}_{width}"
 | 
			
		||||
 | 
			
		||||
            if not torch.compiler.is_compiling():
 | 
			
		||||
            if not torch.compiler.is_compiling() and False:
 | 
			
		||||
                if rope_key not in self.rope_cache:
 | 
			
		||||
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
 | 
			
		||||
                video_freq = self.rope_cache[rope_key]
 | 
			
		||||
@@ -224,7 +224,6 @@ class QwenEmbedRope(nn.Module):
 | 
			
		||||
 | 
			
		||||
        return vid_freqs, txt_freqs
 | 
			
		||||
 | 
			
		||||
    @functools.lru_cache(maxsize=None)
 | 
			
		||||
    def _compute_video_freqs(self, frame, height, width, idx=0):
 | 
			
		||||
        seq_lens = frame * height * width
 | 
			
		||||
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 | 
			
		||||
 | 
			
		||||
@@ -19,7 +19,8 @@ from PIL import Image
 | 
			
		||||
import torchvision.transforms.functional as TF
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from .distributed.fsdp import shard_model
 | 
			
		||||
from .modules.model import WanModel, clear_caches
 | 
			
		||||
from .modules.model import WanModel
 | 
			
		||||
from mmgp.offload import get_cache, clear_caches
 | 
			
		||||
from .modules.t5 import T5EncoderModel
 | 
			
		||||
from .modules.vae import WanVAE
 | 
			
		||||
from .modules.vae2_2 import Wan2_2_VAE
 | 
			
		||||
@@ -496,6 +497,8 @@ class WanAny2V:
 | 
			
		||||
        text_len = self.model.text_len
 | 
			
		||||
        context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) 
 | 
			
		||||
        context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) 
 | 
			
		||||
        if input_video is not None: height, width = input_video.shape[-2:]
 | 
			
		||||
 | 
			
		||||
        # NAG_prompt =  "static, low resolution, blurry"
 | 
			
		||||
        # context_NAG = self.text_encoder([NAG_prompt], self.device)[0]
 | 
			
		||||
        # context_NAG = context_NAG.to(self.dtype)
 | 
			
		||||
@@ -530,9 +533,10 @@ class WanAny2V:
 | 
			
		||||
            if image_start is None:
 | 
			
		||||
                if infinitetalk:
 | 
			
		||||
                    if input_frames is not None:
 | 
			
		||||
                        image_ref = input_frames[:, -1]
 | 
			
		||||
                        if input_video is None: input_video = input_frames[:, -1:] 
 | 
			
		||||
                        image_ref = input_frames[:, 0]
 | 
			
		||||
                        if input_video is None: input_video = input_frames[:, 0:1]
 | 
			
		||||
                        new_shot = "Q" in video_prompt_type
 | 
			
		||||
                        denoising_strength = 0.5
 | 
			
		||||
                    else:
 | 
			
		||||
                        if pre_video_frame is None:
 | 
			
		||||
                            new_shot = True
 | 
			
		||||
@@ -888,6 +892,7 @@ class WanAny2V:
 | 
			
		||||
                    latents[:, :, :extended_overlapped_latents.shape[2]]   = extended_overlapped_latents 
 | 
			
		||||
                else:
 | 
			
		||||
                    latent_noise_factor = t / 1000
 | 
			
		||||
                    latents[:, :, :extended_overlapped_latents.shape[2]]   = extended_overlapped_latents  * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor 
 | 
			
		||||
                if vace:
 | 
			
		||||
                    overlap_noise_factor = overlap_noise / 1000 
 | 
			
		||||
                    for zz in z:
 | 
			
		||||
 | 
			
		||||
@@ -12,6 +12,7 @@ from diffusers.models.modeling_utils import ModelMixin
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Union,Optional
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
from mmgp.offload import get_cache, clear_caches
 | 
			
		||||
from shared.attention import pay_attention
 | 
			
		||||
from torch.backends.cuda import sdp_kernel
 | 
			
		||||
from ..multitalk.multitalk_utils import get_attn_map_with_target
 | 
			
		||||
@@ -19,22 +20,6 @@ from ..multitalk.multitalk_utils import get_attn_map_with_target
 | 
			
		||||
__all__ = ['WanModel']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cache(cache_name):
 | 
			
		||||
    all_cache = offload.shared_state.get("_cache",  None)
 | 
			
		||||
    if all_cache is None:
 | 
			
		||||
        all_cache = {}
 | 
			
		||||
        offload.shared_state["_cache"]=  all_cache
 | 
			
		||||
    cache = offload.shared_state.get(cache_name, None)
 | 
			
		||||
    if cache is None:
 | 
			
		||||
        cache = {}
 | 
			
		||||
        offload.shared_state[cache_name] = cache
 | 
			
		||||
    return cache
 | 
			
		||||
 | 
			
		||||
def clear_caches():
 | 
			
		||||
    all_cache = offload.shared_state.get("_cache",  None)
 | 
			
		||||
    if all_cache is not None:
 | 
			
		||||
        all_cache.clear()
 | 
			
		||||
 | 
			
		||||
def sinusoidal_embedding_1d(dim, position):
 | 
			
		||||
    # preprocess
 | 
			
		||||
    assert dim % 2 == 0
 | 
			
		||||
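The two helpers above are deleted from the model module because this commit imports `get_cache` / `clear_caches` from `mmgp.offload` instead (see the import hunks). A minimal usage sketch, assuming the mmgp versions keep the behaviour of the removed code:

```python
import torch
from mmgp.offload import get_cache, clear_caches   # imports added in this commit (mmgp bumped to 3.5.11)

cache = get_cache("multitalk_rope")        # returns (lazily creating) a named dict shared across modules
freqs = cache.get("encoder_k")             # entries are keyed per call site ("cache_entry" in the RoPE hunk below)
if freqs is None:
    freqs = cache["encoder_k"] = torch.arange(8)   # placeholder payload, purely for illustration
clear_caches()                             # resets the shared cache store, e.g. between generations
```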
@@ -579,19 +564,23 @@ class WanAttentionBlock(nn.Module):
 | 
			
		||||
            y = self.norm_x(x)
 | 
			
		||||
            y = y.to(attention_dtype)
 | 
			
		||||
            if ref_images_count == 0:
 | 
			
		||||
                x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                ylist= [y]
 | 
			
		||||
                del y
 | 
			
		||||
                x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
            else:
 | 
			
		||||
                y_shape = y.shape
 | 
			
		||||
                y = y.reshape(y_shape[0], grid_sizes[0], -1)
 | 
			
		||||
                y = y[:, ref_images_count:]
 | 
			
		||||
                y = y.reshape(y_shape[0], -1, y_shape[-1])
 | 
			
		||||
                grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
 | 
			
		||||
                y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                ylist= [y]
 | 
			
		||||
                y = None
 | 
			
		||||
                y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
 | 
			
		||||
                y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
 | 
			
		||||
                x = x.reshape(y_shape[0], grid_sizes[0], -1)
 | 
			
		||||
                x[:, ref_images_count:] += y
 | 
			
		||||
                x = x.reshape(y_shape[0], -1, y_shape[-1])
 | 
			
		||||
            del y
 | 
			
		||||
                del y
 | 
			
		||||
 | 
			
		||||
        y = self.norm2(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
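The `ylist = [y]` / `del y` handoff above (and in the attention hunks that follow) is a VRAM-saving pattern: the tensor travels inside a list so the callee can drop the caller's last reference before allocating its own buffers. A standalone illustration (not WanGP code):

```python
import torch

def consumer(xlist):
    x = xlist[0]
    xlist.clear()              # the caller's list no longer references the tensor
    out = x * 2                # callee does its work...
    del x                      # ...and can free the input before returning
    return out

y = torch.randn(1, 75_600, 5_120)   # a large hidden-state-sized tensor (sizes are illustrative)
ylist = [y]
del y                                # only the list owns the tensor now
out = consumer(ylist)                # the input buffer is released inside the callee, not after it returns
```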
@@ -221,13 +221,16 @@ class SingleStreamAttention(nn.Module):
 | 
			
		||||
        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
 | 
			
		||||
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
 | 
			
		||||
    def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
 | 
			
		||||
        N_t, N_h, N_w = shape
 | 
			
		||||
 | 
			
		||||
        x = xlist[0]
 | 
			
		||||
        xlist.clear()
 | 
			
		||||
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
 | 
			
		||||
        # get q for hidden_state
 | 
			
		||||
        B, N, C = x.shape
 | 
			
		||||
        q = self.q_linear(x)
 | 
			
		||||
        del x
 | 
			
		||||
        q_shape = (B, N, self.num_heads, self.head_dim)
 | 
			
		||||
        q = q.view(q_shape).permute((0, 2, 1, 3))
 | 
			
		||||
 | 
			
		||||
@@ -247,9 +250,6 @@ class SingleStreamAttention(nn.Module):
 | 
			
		||||
        q = rearrange(q, "B H M K -> B M H K")
 | 
			
		||||
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
 | 
			
		||||
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
 | 
			
		||||
 | 
			
		||||
        attn_bias = None
 | 
			
		||||
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
 | 
			
		||||
        qkv_list = [q, encoder_k, encoder_v]
 | 
			
		||||
        q = encoder_k = encoder_v = None
 | 
			
		||||
        x = pay_attention(qkv_list)
 | 
			
		||||
@@ -302,7 +302,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
 | 
			
		||||
        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    def forward(self, 
 | 
			
		||||
                x: torch.Tensor, 
 | 
			
		||||
                xlist: torch.Tensor, 
 | 
			
		||||
                encoder_hidden_states: torch.Tensor, 
 | 
			
		||||
                shape=None, 
 | 
			
		||||
                x_ref_attn_map=None,
 | 
			
		||||
@@ -310,14 +310,17 @@ class SingleStreamMutiAttention(SingleStreamAttention):
 | 
			
		||||
        
 | 
			
		||||
        encoder_hidden_states = encoder_hidden_states.squeeze(0)
 | 
			
		||||
        if x_ref_attn_map == None:
 | 
			
		||||
            return super().forward(x, encoder_hidden_states, shape)
 | 
			
		||||
            return super().forward(xlist, encoder_hidden_states, shape)
 | 
			
		||||
 | 
			
		||||
        N_t, _, _ = shape 
 | 
			
		||||
        x = xlist[0]
 | 
			
		||||
        xlist.clear()
 | 
			
		||||
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) 
 | 
			
		||||
 | 
			
		||||
        # get q for hidden_state
 | 
			
		||||
        B, N, C = x.shape
 | 
			
		||||
        q = self.q_linear(x) 
 | 
			
		||||
        del x
 | 
			
		||||
        q_shape = (B, N, self.num_heads, self.head_dim) 
 | 
			
		||||
        q = q.view(q_shape).permute((0, 2, 1, 3))
 | 
			
		||||
 | 
			
		||||
@@ -339,7 +342,9 @@ class SingleStreamMutiAttention(SingleStreamAttention):
 | 
			
		||||
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N 
 | 
			
		||||
 | 
			
		||||
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
 | 
			
		||||
        q = self.rope_1d(q, normalized_pos)
 | 
			
		||||
        qlist = [q]
 | 
			
		||||
        del q
 | 
			
		||||
        q = self.rope_1d(qlist, normalized_pos, "q")
 | 
			
		||||
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
 | 
			
		||||
 | 
			
		||||
        _, N_a, _ = encoder_hidden_states.shape 
 | 
			
		||||
@@ -347,7 +352,7 @@ class SingleStreamMutiAttention(SingleStreamAttention):
 | 
			
		||||
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
 | 
			
		||||
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) 
 | 
			
		||||
        encoder_k, encoder_v = encoder_kv.unbind(0) 
 | 
			
		||||
 | 
			
		||||
        del encoder_kv
 | 
			
		||||
        if self.qk_norm:
 | 
			
		||||
            encoder_k = self.add_k_norm(encoder_k)
 | 
			
		||||
 | 
			
		||||
@@ -356,13 +361,14 @@ class SingleStreamMutiAttention(SingleStreamAttention):
 | 
			
		||||
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
 | 
			
		||||
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)
 | 
			
		||||
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
 | 
			
		||||
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
 | 
			
		||||
        enclist = [encoder_k]
 | 
			
		||||
        del encoder_k
 | 
			
		||||
        encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k")
 | 
			
		||||
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
 | 
			
		||||
 
 | 
			
		||||
        q = rearrange(q, "B H M K -> B M H K")
 | 
			
		||||
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
 | 
			
		||||
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
 | 
			
		||||
        # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
 | 
			
		||||
        qkv_list = [q, encoder_k, encoder_v]
 | 
			
		||||
        q = encoder_k = encoder_v = None
 | 
			
		||||
        x = pay_attention(qkv_list)
 | 
			
		||||
 | 
			
		||||
@@ -16,7 +16,7 @@ import torchvision
 | 
			
		||||
import binascii
 | 
			
		||||
import os.path as osp
 | 
			
		||||
from skimage import color
 | 
			
		||||
 | 
			
		||||
from mmgp.offload import get_cache, clear_caches
 | 
			
		||||
 | 
			
		||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
 | 
			
		||||
ASPECT_RATIO_627 = {
 | 
			
		||||
@@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.compile
 | 
			
		||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):
 | 
			
		||||
    
 | 
			
		||||
    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
 | 
			
		||||
def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None):
 | 
			
		||||
    dtype = visual_q.dtype
 | 
			
		||||
    ref_k = ref_k.to(dtype).to(visual_q.device)
 | 
			
		||||
    scale = 1.0 / visual_q.shape[-1] ** 0.5
 | 
			
		||||
    visual_q = visual_q * scale
 | 
			
		||||
    visual_q = visual_q.transpose(1, 2)
 | 
			
		||||
    ref_k = ref_k.transpose(1, 2)
 | 
			
		||||
    visual_q_shape = visual_q.shape
 | 
			
		||||
    visual_q = visual_q.view(-1, visual_q_shape[-1] )
 | 
			
		||||
    number_chunks = visual_q_shape[-2]*ref_k.shape[-2] /  53090100 * 2
 | 
			
		||||
    chunk_size =  int(visual_q_shape[-2] / number_chunks)
 | 
			
		||||
    chunks =torch.split(visual_q, chunk_size)
 | 
			
		||||
    maps_lists = [ [] for _ in ref_target_masks]  
 | 
			
		||||
    for q_chunk  in chunks:
 | 
			
		||||
        attn = q_chunk @ ref_k.transpose(-2, -1)
 | 
			
		||||
        x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
 | 
			
		||||
        del attn
 | 
			
		||||
        ref_target_masks = ref_target_masks.to(dtype)
 | 
			
		||||
        x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
 | 
			
		||||
 | 
			
		||||
        for class_idx, ref_target_mask in enumerate(ref_target_masks):
 | 
			
		||||
            ref_target_mask = ref_target_mask[None, None, None, ...]
 | 
			
		||||
            x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
 | 
			
		||||
            x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
 | 
			
		||||
            maps_lists[class_idx].append(x_ref_attnmap)
 | 
			
		||||
 | 
			
		||||
        del x_ref_attn_map_source
 | 
			
		||||
 | 
			
		||||
    x_ref_attn_maps = []
 | 
			
		||||
    for class_idx, maps_list in enumerate(maps_lists):
 | 
			
		||||
        attn_map_fuse = torch.concat(maps_list, dim= -1)
 | 
			
		||||
        attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1)
 | 
			
		||||
        x_ref_attn_maps.append( attn_map_fuse )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    return torch.concat(x_ref_attn_maps, dim=0)
 | 
			
		||||
 | 
			
		||||
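`calculate_x_ref_attn_map_per_head` above splits `visual_q` into chunks so that only a bounded slice of the (x_seqlens x ref_seqlens) softmax matrix is materialised at once; the constant 53090100 acts as an element budget. A rough sketch of that sizing rule (sequence lengths below are illustrative only):

```python
x_seqlens, ref_seqlens = 75_600, 1_508            # illustrative query and reference lengths
number_chunks = x_seqlens * ref_seqlens / 53_090_100 * 2
chunk_size = int(x_seqlens / number_chunks)        # query rows processed per matmul
elements_per_chunk = chunk_size * ref_seqlens      # stays near 53_090_100 / 2 regardless of resolution
print(number_chunks, chunk_size, elements_per_chunk)
```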
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count):
 | 
			
		||||
    dtype = visual_q.dtype
 | 
			
		||||
    ref_k = ref_k.to(dtype).to(visual_q.device)
 | 
			
		||||
    scale = 1.0 / visual_q.shape[-1] ** 0.5
 | 
			
		||||
    visual_q = visual_q * scale
 | 
			
		||||
    visual_q = visual_q.transpose(1, 2)
 | 
			
		||||
    ref_k = ref_k.transpose(1, 2)
 | 
			
		||||
    attn = visual_q @ ref_k.transpose(-2, -1)
 | 
			
		||||
 | 
			
		||||
    if attn_bias is not None: attn += attn_bias
 | 
			
		||||
 | 
			
		||||
    x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
 | 
			
		||||
 | 
			
		||||
    del attn
 | 
			
		||||
    x_ref_attn_maps = []
 | 
			
		||||
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
 | 
			
		||||
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
 | 
			
		||||
    ref_target_masks = ref_target_masks.to(dtype)
 | 
			
		||||
    x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
 | 
			
		||||
 | 
			
		||||
    for class_idx, ref_target_mask in enumerate(ref_target_masks):
 | 
			
		||||
        ref_target_mask = ref_target_mask[None, None, None, ...]
 | 
			
		||||
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
 | 
			
		||||
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
 | 
			
		||||
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H
 | 
			
		||||
       
 | 
			
		||||
        if mode == 'mean':
 | 
			
		||||
            x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
 | 
			
		||||
        elif mode == 'max':
 | 
			
		||||
            x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens
 | 
			
		||||
        
 | 
			
		||||
        x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens       (mean of heads)
 | 
			
		||||
        x_ref_attn_maps.append(x_ref_attnmap)
 | 
			
		||||
    
 | 
			
		||||
    del attn
 | 
			
		||||
    del x_ref_attn_map_source
 | 
			
		||||
 | 
			
		||||
    return torch.concat(x_ref_attn_maps, dim=0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
 | 
			
		||||
    """Args:
 | 
			
		||||
        query (torch.tensor): B M H K
 | 
			
		||||
@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
 | 
			
		||||
    N_t, N_h, N_w = shape
 | 
			
		||||
    
 | 
			
		||||
    x_seqlens = N_h * N_w
 | 
			
		||||
    if x_seqlens <= 1508:
 | 
			
		||||
        split_num = 10 # 540p
 | 
			
		||||
    else:
 | 
			
		||||
        split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p
 | 
			
		||||
 | 
			
		||||
    ref_k     = ref_k[:, :x_seqlens]
 | 
			
		||||
    if ref_images_count > 0 :
 | 
			
		||||
        visual_q_shape = visual_q.shape 
 | 
			
		||||
@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
 | 
			
		||||
 | 
			
		||||
    split_chunk = heads // split_num
 | 
			
		||||
    
 | 
			
		||||
    for i in range(split_num):
 | 
			
		||||
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
 | 
			
		||||
        x_ref_attn_maps += x_ref_attn_maps_perhead
 | 
			
		||||
    if split_chunk == 1:
 | 
			
		||||
        for i in range(split_num):
 | 
			
		||||
            x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count)
 | 
			
		||||
            x_ref_attn_maps += x_ref_attn_maps_perhead
 | 
			
		||||
    else:
 | 
			
		||||
        for i in range(split_num):
 | 
			
		||||
            x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
 | 
			
		||||
            x_ref_attn_maps += x_ref_attn_maps_perhead
 | 
			
		||||
    
 | 
			
		||||
    x_ref_attn_maps /= split_num
 | 
			
		||||
    return x_ref_attn_maps
 | 
			
		||||
@@ -158,7 +196,6 @@ class RotaryPositionalEmbedding1D(nn.Module):
 | 
			
		||||
        self.base = 10000
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @lru_cache(maxsize=32)
 | 
			
		||||
    def precompute_freqs_cis_1d(self, pos_indices):
 | 
			
		||||
 | 
			
		||||
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
 | 
			
		||||
@@ -167,7 +204,7 @@ class RotaryPositionalEmbedding1D(nn.Module):
 | 
			
		||||
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
 | 
			
		||||
        return freqs
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, pos_indices):
 | 
			
		||||
    def forward(self, qlist, pos_indices, cache_entry = None):
 | 
			
		||||
        """1D RoPE.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
@@ -176,16 +213,26 @@ class RotaryPositionalEmbedding1D(nn.Module):
 | 
			
		||||
        Returns:
 | 
			
		||||
            query with the same shape as input.
 | 
			
		||||
        """
 | 
			
		||||
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
 | 
			
		||||
 | 
			
		||||
        x_ = x.float()
 | 
			
		||||
 | 
			
		||||
        freqs_cis = freqs_cis.float().to(x.device)
 | 
			
		||||
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
 | 
			
		||||
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
 | 
			
		||||
        x_ = (x_ * cos) + (rotate_half(x_) * sin)
 | 
			
		||||
 | 
			
		||||
        return x_.type_as(x)
 | 
			
		||||
        xq= qlist[0]
 | 
			
		||||
        qlist.clear()
 | 
			
		||||
        cache = get_cache("multitalk_rope")
 | 
			
		||||
        freqs_cis= cache.get(cache_entry, None)
 | 
			
		||||
        if freqs_cis is None:
 | 
			
		||||
            freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices)
 | 
			
		||||
        cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0)
 | 
			
		||||
        # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
 | 
			
		||||
        # real * cos - imag * sin
 | 
			
		||||
        # imag * cos + real * sin
 | 
			
		||||
        xq_dtype = xq.dtype
 | 
			
		||||
        xq_out = xq.to(torch.float)
 | 
			
		||||
        xq = None        
 | 
			
		||||
        xq_rot = rotate_half(xq_out)
 | 
			
		||||
        xq_out *= cos
 | 
			
		||||
        xq_rot *= sin
 | 
			
		||||
        xq_out += xq_rot
 | 
			
		||||
        del xq_rot
 | 
			
		||||
        xq_out = xq_out.to(xq_dtype)
 | 
			
		||||
        return xq_out 
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
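For reference, the rewritten `forward` above performs the standard rotary update `q * cos + rotate_half(q) * sin` in place, with `freqs_cis` now cached under `get_cache("multitalk_rope")` and keyed by `cache_entry`. A self-contained sketch of the arithmetic (it assumes an interleaved-pair `rotate_half`, matching the `repeat(..., "... n -> ... (n r)", r=2)` layout above; the module's own helper may differ):

```python
import torch
from einops import rearrange

def rotate_half(x):                       # interleaved-pair rotation (assumption, see note above)
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")

def apply_rope_1d(q, freqs_cis):          # q: (B, H, S, C), freqs_cis: (S, C)
    cos = freqs_cis.cos().unsqueeze(0).unsqueeze(0)
    sin = freqs_cis.sin().unsqueeze(0).unsqueeze(0)
    out = q.float()
    rot = rotate_half(out)
    out *= cos
    rot *= sin
    out += rot                            # (q * cos) + (rotate_half(q) * sin), done in place
    return out.to(q.dtype)
```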
@@ -1,5 +1,6 @@
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
def test_class_i2v(base_model_type):    
 | 
			
		||||
    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p",  "fantasy",  "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
 | 
			
		||||
@@ -116,6 +117,11 @@ class family_handler():
 | 
			
		||||
            extra_model_def["no_background_removal"] = True
 | 
			
		||||
            # extra_model_def["at_least_one_image_ref_needed"] = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["phantom_1.3B", "phantom_14B"]: 
 | 
			
		||||
            extra_model_def["one_image_ref_needed"] = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return extra_model_def
 | 
			
		||||
        
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@@ -235,6 +241,14 @@ class family_handler():
 | 
			
		||||
                if "I" in video_prompt_type:
 | 
			
		||||
                    video_prompt_type = video_prompt_type.replace("KI", "QKI")
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
        if settings_version < 2.28:
 | 
			
		||||
            if base_model_type in "infinitetalk":
 | 
			
		||||
                video_prompt_type = ui_defaults.get("video_prompt_type", "")
 | 
			
		||||
                if "U" in video_prompt_type:
 | 
			
		||||
                    video_prompt_type = video_prompt_type.replace("U", "RU")
 | 
			
		||||
                    ui_defaults["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def update_default_settings(base_model_type, model_def, ui_defaults):
 | 
			
		||||
        ui_defaults.update({
 | 
			
		||||
@ -309,3 +323,11 @@ class family_handler():
 | 
			
		||||
            if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None:
 | 
			
		||||
                video_prompt_type = video_prompt_type.replace("I", "").replace("K","")
 | 
			
		||||
                inputs["video_prompt_type"] = video_prompt_type 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["vace_standin_14B"]:
 | 
			
		||||
            image_refs = inputs["image_refs"]
 | 
			
		||||
            video_prompt_type = inputs["video_prompt_type"]
 | 
			
		||||
            if image_refs is not None and  len(image_refs) == 1 and "K" in video_prompt_type:
 | 
			
		||||
                gr.Info("Warning, Ref Image for StandIn missing: if 'Landscape and then People or Objects' is selected, there should be another Image Ref containing a Face in addition to the Landscape Image Ref.")
 | 
			
		||||
                    
 | 
			
		||||
 | 
			
		||||
@@ -20,7 +20,8 @@ soundfile
 | 
			
		||||
mutagen
 | 
			
		||||
pyloudnorm
 | 
			
		||||
librosa==0.11.0
 | 
			
		||||
 | 
			
		||||
speechbrain==1.0.3
 | 
			
		||||
 
 | 
			
		||||
# UI & interaction
 | 
			
		||||
gradio==5.23.0
 | 
			
		||||
dashscope
 | 
			
		||||
@@ -43,7 +44,7 @@ pydantic==2.10.6
 | 
			
		||||
# Math & modeling
 | 
			
		||||
torchdiffeq>=0.2.5
 | 
			
		||||
tensordict>=0.6.1
 | 
			
		||||
mmgp==3.5.10
 | 
			
		||||
mmgp==3.5.11
 | 
			
		||||
peft==0.15.0
 | 
			
		||||
matplotlib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -3,6 +3,7 @@ import torch
 | 
			
		||||
from importlib.metadata import version
 | 
			
		||||
from mmgp import offload
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
major, minor = torch.cuda.get_device_capability(None)
 | 
			
		||||
bfloat16_supported =  major >= 8 
 | 
			
		||||
@@ -42,34 +43,51 @@ except ImportError:
 | 
			
		||||
    sageattn_varlen_wrapper = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from sageattention import sageattn
 | 
			
		||||
    from .sage2_core import sageattn as alt_sageattn, is_sage2_supported
 | 
			
		||||
    from .sage2_core import sageattn as sageattn2, is_sage2_supported
 | 
			
		||||
    sage2_supported =  is_sage2_supported()
 | 
			
		||||
except ImportError:
 | 
			
		||||
    sageattn = None
 | 
			
		||||
    alt_sageattn = None
 | 
			
		||||
    sageattn2 = None
 | 
			
		||||
    sage2_supported = False
 | 
			
		||||
# @torch.compiler.disable()
 | 
			
		||||
def sageattn_wrapper(
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sageattn2_wrapper(
 | 
			
		||||
        qkv_list,
 | 
			
		||||
        attention_length
 | 
			
		||||
    ):
 | 
			
		||||
    q,k, v = qkv_list
 | 
			
		||||
    if True:
 | 
			
		||||
        qkv_list = [q,k,v]
 | 
			
		||||
        del q, k ,v
 | 
			
		||||
        o = alt_sageattn(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    else:
 | 
			
		||||
        o = sageattn(q, k, v, tensor_layout="NHD")
 | 
			
		||||
        del q, k ,v
 | 
			
		||||
 | 
			
		||||
    qkv_list = [q,k,v]
 | 
			
		||||
    del q, k ,v
 | 
			
		||||
    o = sageattn2(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    qkv_list.clear()
 | 
			
		||||
 | 
			
		||||
    return o
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from sageattn import sageattn_blackwell as sageattn3
 | 
			
		||||
except ImportError:
 | 
			
		||||
    sageattn3 = None
 | 
			
		||||
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sageattn3_wrapper(
 | 
			
		||||
        qkv_list,
 | 
			
		||||
        attention_length
 | 
			
		||||
    ):
 | 
			
		||||
    q,k, v = qkv_list
 | 
			
		||||
    # qkv_list = [q,k,v]
 | 
			
		||||
    # del q, k ,v
 | 
			
		||||
    # o = sageattn3(qkv_list, tensor_layout="NHD")
 | 
			
		||||
    q = q.transpose(1,2)
 | 
			
		||||
    k = k.transpose(1,2)
 | 
			
		||||
    v = v.transpose(1,2)
 | 
			
		||||
    o = sageattn3(q, k, v)
 | 
			
		||||
    o = o.transpose(1,2)
 | 
			
		||||
    qkv_list.clear()
 | 
			
		||||
 | 
			
		||||
    return o
 | 
			
		||||
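Note the layout handling: `sageattn2_wrapper` passes tensors with `tensor_layout="NHD"` (batch, seq, heads, head_dim), while `sageattn3_wrapper` transposes to (batch, heads, seq, head_dim) before calling `sageattn_blackwell` and transposes the result back. Shapes only, as a sketch:

```python
import torch

b, s, h, d = 1, 1024, 12, 128
q_nhd = torch.randn(b, s, h, d)     # layout handed to sageattn2 (tensor_layout="NHD")
q_hnd = q_nhd.transpose(1, 2)       # (b, h, s, d): layout sageattn3_wrapper builds via .transpose(1, 2)
assert q_hnd.shape == (b, h, s, d)
o = q_hnd.transpose(1, 2)           # and the output is transposed back to (b, s, h, d)
assert o.shape == (b, s, h, d)
```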
 | 
			
		||||
     
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# try:
 | 
			
		||||
# if True:
 | 
			
		||||
    # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
@@ -94,7 +112,7 @@ def sageattn_wrapper(
 | 
			
		||||
 | 
			
		||||
    #     return o
 | 
			
		||||
# except ImportError:
 | 
			
		||||
#     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
#     sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda
 | 
			
		||||
 | 
			
		||||
@torch.compiler.disable()
 | 
			
		||||
def sdpa_wrapper(
 | 
			
		||||
@@ -124,21 +142,28 @@ def get_attention_modes():
 | 
			
		||||
        ret.append("xformers")
 | 
			
		||||
    if sageattn_varlen_wrapper != None:
 | 
			
		||||
        ret.append("sage")
 | 
			
		||||
    if sageattn != None and version("sageattention").startswith("2") :
 | 
			
		||||
    if sageattn2 != None and version("sageattention").startswith("2") :
 | 
			
		||||
        ret.append("sage2")
 | 
			
		||||
    if sageattn3 != None: # and version("sageattention").startswith("3") :
 | 
			
		||||
        ret.append("sage3")
 | 
			
		||||
        
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
def get_supported_attention_modes():
 | 
			
		||||
    ret = get_attention_modes()
 | 
			
		||||
    major, minor = torch.cuda.get_device_capability()
 | 
			
		||||
    if  major < 10:
 | 
			
		||||
        if "sage3" in ret:
 | 
			
		||||
            ret.remove("sage3")
 | 
			
		||||
 | 
			
		||||
    if not sage2_supported:
 | 
			
		||||
        if "sage2" in ret:
 | 
			
		||||
            ret.remove("sage2")
 | 
			
		||||
 | 
			
		||||
    major, minor = torch.cuda.get_device_capability()
 | 
			
		||||
    if  major < 7:
 | 
			
		||||
        if "sage" in ret:
 | 
			
		||||
            ret.remove("sage")
 | 
			
		||||
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
@@ -201,7 +226,7 @@ def pay_attention(
 | 
			
		||||
        from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
 | 
			
		||||
        from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
 | 
			
		||||
 | 
			
		||||
    if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
 | 
			
		||||
    if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
 | 
			
		||||
        assert attention_mask == None
 | 
			
		||||
        # Poor man's var k len attention
 | 
			
		||||
        assert q_lens == None
 | 
			
		||||
@@ -234,7 +259,7 @@ def pay_attention(
 | 
			
		||||
            q_chunks, k_chunks, v_chunks = None, None, None
 | 
			
		||||
            o = torch.cat(o, dim = 0)
 | 
			
		||||
            return o
 | 
			
		||||
    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"):
 | 
			
		||||
    elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"):
 | 
			
		||||
        assert b == 1
 | 
			
		||||
        szq = q_lens[0].item() if q_lens != None else lq
 | 
			
		||||
        szk = k_lens[0].item() if k_lens != None else lk
 | 
			
		||||
@@ -284,13 +309,19 @@ def pay_attention(
 | 
			
		||||
            max_seqlen_q=lq,
 | 
			
		||||
            max_seqlen_kv=lk,
 | 
			
		||||
        ).unflatten(0, (b, lq))
 | 
			
		||||
    elif attn=="sage3":
 | 
			
		||||
        import math
 | 
			
		||||
        if cross_attn or True:
 | 
			
		||||
            qkv_list = [q,k,v]
 | 
			
		||||
            del q,k,v
 | 
			
		||||
            x = sageattn3_wrapper(qkv_list, lq)
 | 
			
		||||
    elif attn=="sage2":
 | 
			
		||||
        import math
 | 
			
		||||
        if cross_attn or True:
 | 
			
		||||
            qkv_list = [q,k,v]
 | 
			
		||||
            del q,k,v
 | 
			
		||||
 | 
			
		||||
            x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
 | 
			
		||||
            x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0)
 | 
			
		||||
        # else:
 | 
			
		||||
        #     layer =  offload.shared_state["layer"]
 | 
			
		||||
        #     embed_sizes = offload.shared_state["embed_sizes"] 
 | 
			
		||||
 | 
			
		||||
							
								
								
									
shared/gradio/gallery.py: 496 lines (new file)
@@ -0,0 +1,496 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
import os, io, tempfile, mimetypes
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
 | 
			
		||||
 | 
			
		||||
import gradio as gr
 | 
			
		||||
import PIL
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
 | 
			
		||||
FilePath = str
 | 
			
		||||
ImageLike = Union["PIL.Image.Image", Any]
 | 
			
		||||
 | 
			
		||||
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
 | 
			
		||||
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
 | 
			
		||||
 | 
			
		||||
def get_state(state):
 | 
			
		||||
    return state if isinstance(state, dict) else state.value
 | 
			
		||||
 | 
			
		||||
def get_list( objs):
 | 
			
		||||
    if objs is None:
 | 
			
		||||
        return []
 | 
			
		||||
    return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
 | 
			
		||||
 | 
			
		||||
class AdvancedMediaGallery:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        label: str = "Media",
 | 
			
		||||
        *,
 | 
			
		||||
        media_mode: Literal["image", "video"] = "image",
 | 
			
		||||
        height = None,
 | 
			
		||||
        columns: Union[int, Tuple[int, ...]] = 6,
 | 
			
		||||
        show_label: bool = True,
 | 
			
		||||
        initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None,
 | 
			
		||||
        elem_id: Optional[str] = None,
 | 
			
		||||
        elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",),
 | 
			
		||||
        accept_filter: bool = True,        # restrict Add-button dialog to allowed extensions
 | 
			
		||||
        single_image_mode: bool = False,   # start in single-image mode (Add replaces)
 | 
			
		||||
    ):
 | 
			
		||||
        assert media_mode in ("image", "video")
 | 
			
		||||
        self.label = label
 | 
			
		||||
        self.media_mode = media_mode
 | 
			
		||||
        self.height = height
 | 
			
		||||
        self.columns = columns
 | 
			
		||||
        self.show_label = show_label
 | 
			
		||||
        self.elem_id = elem_id
 | 
			
		||||
        self.elem_classes = list(elem_classes) if elem_classes else None
 | 
			
		||||
        self.accept_filter = accept_filter
 | 
			
		||||
 | 
			
		||||
        items = self._normalize_initial(initial or [], media_mode)
 | 
			
		||||
 | 
			
		||||
        # Components (filled on mount)
 | 
			
		||||
        self.container: Optional[gr.Column] = None
 | 
			
		||||
        self.gallery: Optional[gr.Gallery] = None
 | 
			
		||||
        self.upload_btn: Optional[gr.UploadButton] = None
 | 
			
		||||
        self.btn_remove: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_left: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_right: Optional[gr.Button] = None
 | 
			
		||||
        self.btn_clear: Optional[gr.Button] = None
 | 
			
		||||
 | 
			
		||||
        # Single dict state
 | 
			
		||||
        self.state: Optional[gr.State] = None
 | 
			
		||||
        self._initial_state: Dict[str, Any] = {
 | 
			
		||||
            "items": items,
 | 
			
		||||
            "selected": (len(items) - 1) if items else None,
 | 
			
		||||
            "single": bool(single_image_mode),
 | 
			
		||||
            "mode": self.media_mode,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    # ---------------- helpers ----------------
 | 
			
		||||
 | 
			
		||||
    def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]:
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        if mode == "image":
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._ensure_image_item(it)
 | 
			
		||||
                if p is not None:
 | 
			
		||||
                    out.append(p)
 | 
			
		||||
        else:
 | 
			
		||||
            for it in items:
 | 
			
		||||
                if isinstance(it, tuple): it = it[0]
 | 
			
		||||
                if isinstance(it, str) and self._is_video_path(it):
 | 
			
		||||
                    out.append(os.path.abspath(it))
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]:
 | 
			
		||||
        # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path
 | 
			
		||||
        if isinstance(item, tuple): item = item[0]
 | 
			
		||||
        if isinstance(item, str):
 | 
			
		||||
            return os.path.abspath(item) if self._is_image_path(item) else None
 | 
			
		||||
        if PILImage is None:
 | 
			
		||||
            return None
 | 
			
		||||
        try:
 | 
			
		||||
            if isinstance(item, PILImage.Image):
 | 
			
		||||
                img = item
 | 
			
		||||
            else:
 | 
			
		||||
                import numpy as np  # type: ignore
 | 
			
		||||
                if isinstance(item, np.ndarray):
 | 
			
		||||
                    img = PILImage.fromarray(item)
 | 
			
		||||
                elif hasattr(item, "read"):
 | 
			
		||||
                    data = item.read()
 | 
			
		||||
                    img = PILImage.open(io.BytesIO(data)).convert("RGBA")
 | 
			
		||||
                else:
 | 
			
		||||
                    return None
 | 
			
		||||
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
 | 
			
		||||
            img.save(tmp.name)
 | 
			
		||||
            return tmp.name
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _extract_path(obj: Any) -> Optional[str]:
 | 
			
		||||
        # Try to get a filesystem path (for mode filtering); otherwise None.
 | 
			
		||||
        if isinstance(obj, str):
 | 
			
		||||
            return obj
 | 
			
		||||
        try:
 | 
			
		||||
            import pathlib
 | 
			
		||||
            if isinstance(obj, pathlib.Path):  # type: ignore
 | 
			
		||||
                return str(obj)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
        if isinstance(obj, dict):
 | 
			
		||||
            return obj.get("path") or obj.get("name")
 | 
			
		||||
        for attr in ("path", "name"):
 | 
			
		||||
            if hasattr(obj, attr):
 | 
			
		||||
                try:
 | 
			
		||||
                    val = getattr(obj, attr)
 | 
			
		||||
                    if isinstance(val, str):
 | 
			
		||||
                        return val
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    pass
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _is_image_path(p: str) -> bool:
 | 
			
		||||
        ext = os.path.splitext(p)[1].lower()
 | 
			
		||||
        if ext in IMAGE_EXTS:
 | 
			
		||||
            return True
 | 
			
		||||
        mt, _ = mimetypes.guess_type(p)
 | 
			
		||||
        return bool(mt and mt.startswith("image/"))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _is_video_path(p: str) -> bool:
 | 
			
		||||
        ext = os.path.splitext(p)[1].lower()
 | 
			
		||||
        if ext in VIDEO_EXTS:
 | 
			
		||||
            return True
 | 
			
		||||
        mt, _ = mimetypes.guess_type(p)
 | 
			
		||||
        return bool(mt and mt.startswith("video/"))
 | 
			
		||||
 | 
			
		||||
    def _filter_items_by_mode(self, items: List[Any]) -> List[Any]:
 | 
			
		||||
        # Enforce image-only or video-only collection regardless of how files were added.
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        if self.media_mode == "image":
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._extract_path(it)
 | 
			
		||||
                if p is None:
 | 
			
		||||
                    # No path: likely an image object added programmatically => keep
 | 
			
		||||
                    out.append(it)
 | 
			
		||||
                elif self._is_image_path(p):
 | 
			
		||||
                    out.append(os.path.abspath(p))
 | 
			
		||||
        else:
 | 
			
		||||
            for it in items:
 | 
			
		||||
                p = self._extract_path(it)
 | 
			
		||||
                if p is not None and self._is_video_path(p):
 | 
			
		||||
                    out.append(os.path.abspath(p))
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]:
 | 
			
		||||
        # Keep it simple: dedupe by path when available, else allow duplicates.
 | 
			
		||||
        seen_paths = set()
 | 
			
		||||
        def key(x: Any) -> Optional[str]:
 | 
			
		||||
            if isinstance(x, str): return os.path.abspath(x)
 | 
			
		||||
            try:
 | 
			
		||||
                import pathlib
 | 
			
		||||
                if isinstance(x, pathlib.Path):  # type: ignore
 | 
			
		||||
                    return os.path.abspath(str(x))
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
            if isinstance(x, dict):
 | 
			
		||||
                p = x.get("path") or x.get("name")
 | 
			
		||||
                return os.path.abspath(p) if isinstance(p, str) else None
 | 
			
		||||
            for attr in ("path", "name"):
 | 
			
		||||
                if hasattr(x, attr):
 | 
			
		||||
                    try:
 | 
			
		||||
                        v = getattr(x, attr)
 | 
			
		||||
                        return os.path.abspath(v) if isinstance(v, str) else None
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        pass
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        out: List[Any] = []
 | 
			
		||||
        for lst in (cur, add):
 | 
			
		||||
            for it in lst:
 | 
			
		||||
                k = key(it)
 | 
			
		||||
                if k is None or k not in seen_paths:
 | 
			
		||||
                    out.append(it)
 | 
			
		||||
                    if k is not None:
 | 
			
		||||
                        seen_paths.add(k)
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _paths_from_payload(payload: Any) -> List[Any]:
 | 
			
		||||
        # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly.
 | 
			
		||||
        if payload is None:
 | 
			
		||||
            return []
 | 
			
		||||
        if isinstance(payload, (list, tuple, set)):
 | 
			
		||||
            return list(payload)
 | 
			
		||||
        return [payload]
 | 
			
		||||
 | 
			
		||||
    # ---------------- event handlers ----------------
 | 
			
		||||
 | 
			
		||||
    def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
 | 
			
		||||
        # Mirror the selected index into state and the gallery (server-side selected_index)
 | 
			
		||||
        idx = None
 | 
			
		||||
        if evt is not None and hasattr(evt, "index"):
 | 
			
		||||
            ix = evt.index
 | 
			
		||||
            if isinstance(ix, int):
 | 
			
		||||
                idx = ix
 | 
			
		||||
            elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int):
 | 
			
		||||
                if isinstance(self.columns, int) and len(ix) >= 2:
 | 
			
		||||
                    idx = ix[0] * max(1, int(self.columns)) + ix[1]
 | 
			
		||||
                else:
 | 
			
		||||
                    idx = ix[0]
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        n = len(get_list(gallery))
 | 
			
		||||
        sel = idx if (idx is not None and 0 <= idx < n) else None
 | 
			
		||||
        st["selected"] = sel
 | 
			
		||||
        return gr.update(selected_index=sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
 | 
			
		||||
        # Fires when users add/drag/drop/delete via the Gallery itself.
 | 
			
		||||
        items_filtered = self._filter_items_by_mode(list(value or []))
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        st["items"] = items_filtered
 | 
			
		||||
        # Keep selection if still valid, else default to last
 | 
			
		||||
        old_sel = st.get("selected", None)
 | 
			
		||||
        if old_sel is None or not (0 <= old_sel < len(items_filtered)):
 | 
			
		||||
            new_sel = (len(items_filtered) - 1) if items_filtered else None
 | 
			
		||||
        else:
 | 
			
		||||
            new_sel = old_sel
 | 
			
		||||
        st["selected"] = new_sel
 | 
			
		||||
        return gr.update(value=items_filtered, selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
 | 
			
		||||
        """
 | 
			
		||||
        Insert added items right AFTER the currently selected index.
 | 
			
		||||
        Keeps the same ordering as chosen in the file picker, dedupes by path,
 | 
			
		||||
        and re-selects the last inserted item.
 | 
			
		||||
        """
 | 
			
		||||
        # New items (respect image/video mode)
 | 
			
		||||
        new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
 | 
			
		||||
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        cur: List[Any] = get_list(gallery)
 | 
			
		||||
        sel = st.get("selected", None)
 | 
			
		||||
        if sel is None:
 | 
			
		||||
            sel = (len(cur) -1) if len(cur)>0 else 0
 | 
			
		||||
        single = bool(st.get("single", False))
 | 
			
		||||
 | 
			
		||||
        # Nothing to add: keep as-is
 | 
			
		||||
        if not new_items:
 | 
			
		||||
            return gr.update(value=cur, selected_index=st.get("selected")), st
 | 
			
		||||
 | 
			
		||||
        # Single-image mode: replace
 | 
			
		||||
        if single:
 | 
			
		||||
            st["items"] = [new_items[-1]]
 | 
			
		||||
            st["selected"] = 0
 | 
			
		||||
            return gr.update(value=st["items"], selected_index=0), st
 | 
			
		||||
 | 
			
		||||
        # ---------- helpers ----------
 | 
			
		||||
        def key_of(it: Any) -> Optional[str]:
 | 
			
		||||
            # Prefer class helper if present
 | 
			
		||||
            if hasattr(self, "_extract_path"):
 | 
			
		||||
                p = self._extract_path(it)  # type: ignore
 | 
			
		||||
            else:
 | 
			
		||||
                p = it if isinstance(it, str) else None
 | 
			
		||||
                if p is None and isinstance(it, dict):
 | 
			
		||||
                    p = it.get("path") or it.get("name")
 | 
			
		||||
                if p is None and hasattr(it, "path"):
 | 
			
		||||
                    try: p = getattr(it, "path")
 | 
			
		||||
                    except Exception: p = None
 | 
			
		||||
                if p is None and hasattr(it, "name"):
 | 
			
		||||
                    try: p = getattr(it, "name")
 | 
			
		||||
                    except Exception: p = None
 | 
			
		||||
            return os.path.abspath(p) if isinstance(p, str) else None
 | 
			
		||||
 | 
			
		||||
        # Dedupe the incoming batch by path, preserve order
 | 
			
		||||
        seen_new = set()
 | 
			
		||||
        incoming: List[Any] = []
 | 
			
		||||
        for it in new_items:
 | 
			
		||||
            k = key_of(it)
 | 
			
		||||
            if k is None or k not in seen_new:
 | 
			
		||||
                incoming.append(it)
 | 
			
		||||
                if k is not None:
 | 
			
		||||
                    seen_new.add(k)
 | 
			
		||||
 | 
			
		||||
        # Remove any existing occurrences of the incoming items from current list,
 | 
			
		||||
        # BUT keep the currently selected item even if it's also in incoming.
 | 
			
		||||
        cur_clean: List[Any] = []
 | 
			
		||||
        # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
 | 
			
		||||
        # for idx, it in enumerate(cur):
 | 
			
		||||
        #     k = key_of(it)
 | 
			
		||||
        #     if it is sel_item:
 | 
			
		||||
        #         cur_clean.append(it)
 | 
			
		||||
        #         continue
 | 
			
		||||
        #     if k is not None and k in seen_new:
 | 
			
		||||
        #         continue  # drop duplicate; we'll reinsert at the target spot
 | 
			
		||||
        #     cur_clean.append(it)
 | 
			
		||||
 | 
			
		||||
        # # Compute insertion position: right AFTER the (possibly shifted) selected item
 | 
			
		||||
        # if sel_item is not None:
 | 
			
		||||
        #     # find sel_item's new index in cur_clean
 | 
			
		||||
        #     try:
 | 
			
		||||
        #         pos_sel = cur_clean.index(sel_item)
 | 
			
		||||
        #     except ValueError:
 | 
			
		||||
        #         # Shouldn't happen, but fall back to end
 | 
			
		||||
        #         pos_sel = len(cur_clean) - 1
 | 
			
		||||
        #     insert_pos = pos_sel + 1
 | 
			
		||||
        # else:
 | 
			
		||||
        #     insert_pos = len(cur_clean)  # no selection -> append at end
 | 
			
		||||
        insert_pos = min(sel, len(cur) -1)
 | 
			
		||||
        cur_clean = cur
 | 
			
		||||
        # Build final list and selection
 | 
			
		||||
        merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:]
 | 
			
		||||
        new_sel = insert_pos + len(incoming)   # select the last inserted item
 | 
			
		||||
 | 
			
		||||
        st["items"] = merged
 | 
			
		||||
        st["selected"] = new_sel
 | 
			
		||||
        return gr.update(value=merged, selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_remove(self, state: Dict[str, Any], gallery) :
 | 
			
		||||
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
 | 
			
		||||
        if sel is None or not (0 <= sel < len(items)):
 | 
			
		||||
            return gr.update(value=items, selected_index=st.get("selected")), st
 | 
			
		||||
        items.pop(sel)
 | 
			
		||||
        if not items:
 | 
			
		||||
            st["items"] = []; st["selected"] = None
 | 
			
		||||
            return gr.update(value=[], selected_index=None), st
 | 
			
		||||
        new_sel = min(sel, len(items) - 1)
 | 
			
		||||
        st["items"] = items; st["selected"] = new_sel
 | 
			
		||||
        return gr.update(value=items, selected_index=new_sel), st
 | 
			
		||||
 | 
			
		||||
    def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
 | 
			
		||||
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
 | 
			
		||||
        if sel is None or not (0 <= sel < len(items)):
 | 
			
		||||
            return gr.update(value=items, selected_index=sel), st
 | 
			
		||||
        j = sel + delta
 | 
			
		||||
        if j < 0 or j >= len(items):
 | 
			
		||||
            return gr.update(value=items, selected_index=sel), st
 | 
			
		||||
        items[sel], items[j] = items[j], items[sel]
 | 
			
		||||
        st["items"] = items; st["selected"] = j
 | 
			
		||||
        return gr.update(value=items, selected_index=j), st
 | 
			
		||||
 | 
			
		||||
    def _on_clear(self, state: Dict[str, Any]) :
 | 
			
		||||
        st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode}
 | 
			
		||||
        return gr.update(value=[], selected_index=None), st

    def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]):
        st = get_state(state); st["single"] = bool(to_single)
        items: List[Any] = list(st["items"]); sel = st.get("selected", None)
        if st["single"]:
            keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None)
            items = [keep] if keep is not None else []
            sel = 0 if items else None
        st["items"] = items; st["selected"] = sel

        upload_update = gr.update(file_count=("single" if st["single"] else "multiple"))
        left_update   = gr.update(visible=not st["single"])
        right_update  = gr.update(visible=not st["single"])
        clear_update  = gr.update(visible=not st["single"])
        gallery_update = gr.update(value=items, selected_index=sel)

        return upload_update, left_update, right_update, clear_update, gallery_update, st
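        # Note: when switching to single-image mode, only the currently selected item (or the last
        # one if nothing is selected) is kept; the Left / Right / Clear buttons are hidden and the
        # upload button switches to single-file mode.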
 | 
			
		||||
 | 
			
		||||
    # ---------------- build & wire ----------------
 | 
			
		||||
 | 
			
		||||
    def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
 | 
			
		||||
        if parent is not None:
 | 
			
		||||
            with parent:
 | 
			
		||||
                col = self._build_ui()
 | 
			
		||||
        else:
 | 
			
		||||
            col = self._build_ui()
 | 
			
		||||
        if not update_form:
 | 
			
		||||
            self._wire_events()
 | 
			
		||||
        return col
 | 
			
		||||
 | 
			
		||||
    def _build_ui(self) -> gr.Column:
 | 
			
		||||
        with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
 | 
			
		||||
            self.container = col
 | 
			
		||||
 | 
			
		||||
            self.state = gr.State(dict(self._initial_state))
 | 
			
		||||
 | 
			
		||||
            self.gallery = gr.Gallery(
 | 
			
		||||
                label=self.label,
 | 
			
		||||
                value=self._initial_state["items"],
 | 
			
		||||
                height=self.height,
 | 
			
		||||
                columns=self.columns,
 | 
			
		||||
                show_label=self.show_label,
 | 
			
		||||
                preview= True,
 | 
			
		||||
                selected_index=self._initial_state["selected"],  # server-side selection
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # One-line controls
 | 
			
		||||
            exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
 | 
			
		||||
            with gr.Row(equal_height=True, elem_classes=["amg-controls"]):
 | 
			
		||||
                self.upload_btn = gr.UploadButton(
 | 
			
		||||
                    "Set" if self._initial_state["single"] else "Add",
 | 
			
		||||
                    file_types=exts,
 | 
			
		||||
                    file_count=("single" if self._initial_state["single"] else "multiple"),
 | 
			
		||||
                    variant="primary",
 | 
			
		||||
                    size="sm",
 | 
			
		||||
                    min_width=1,
 | 
			
		||||
                )
 | 
			
		||||
                self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
 | 
			
		||||
                self.btn_left   = gr.Button("◀ Left",  size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
                self.btn_right  = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
                self.btn_clear  = gr.Button("Clear",   variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
 | 
			
		||||
 | 
			
		||||
        return col
 | 
			
		||||
 | 
			
		||||
    def _wire_events(self):
 | 
			
		||||
        # Selection: mirror into state and keep gallery.selected_index in sync
 | 
			
		||||
        self.gallery.select(
 | 
			
		||||
            self._on_select,
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
 | 
			
		||||
        self.gallery.change(
 | 
			
		||||
            self._on_gallery_change,
 | 
			
		||||
            inputs=[self.gallery, self.state],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Add via UploadButton
 | 
			
		||||
        self.upload_btn.upload(
 | 
			
		||||
            self._on_add,
 | 
			
		||||
            inputs=[self.upload_btn, self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Remove selected
 | 
			
		||||
        self.btn_remove.click(
 | 
			
		||||
            self._on_remove,
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Reorder using selected index, keep same item selected
 | 
			
		||||
        self.btn_left.click(
 | 
			
		||||
            lambda st, gallery: self._on_move(-1, st, gallery),
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
        self.btn_right.click(
 | 
			
		||||
            lambda st, gallery: self._on_move(+1, st, gallery),
 | 
			
		||||
            inputs=[self.state, self.gallery],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Clear all
 | 
			
		||||
        self.btn_clear.click(
 | 
			
		||||
            self._on_clear,
 | 
			
		||||
            inputs=[self.state],
 | 
			
		||||
            outputs=[self.gallery, self.state],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # ---------------- public API ----------------
 | 
			
		||||
 | 
			
		||||
    def set_one_image_mode(self, enabled: bool = True):
 | 
			
		||||
        """Toggle single-image mode at runtime."""
 | 
			
		||||
        return (
 | 
			
		||||
            self._on_toggle_single,
 | 
			
		||||
            [gr.State(enabled), self.state],
 | 
			
		||||
            [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state],
 | 
			
		||||
        )
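        # The (fn, inputs, outputs) triple returned above is meant to be unpacked into a Gradio
        # event, e.g. gr.Button("Single image").click(*amg.set_one_image_mode(True)), as in the
        # commented demo at the bottom of this file.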
 | 
			
		||||
 | 
			
		||||
    def get_toggable_elements(self):
 | 
			
		||||
        return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state]
 | 
			
		||||
 | 
			
		||||
# import gradio as gr
 | 
			
		||||
 | 
			
		||||
# with gr.Blocks() as demo:
 | 
			
		||||
#     amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8)
 | 
			
		||||
#     amg.mount()
 | 
			
		||||
#     g = amg.gallery
 | 
			
		||||
#     # buttons to switch modes live (optional)
 | 
			
		||||
#     def process(g):
 | 
			
		||||
#         pass
 | 
			
		||||
#     with gr.Row():
 | 
			
		||||
#         gr.Button("toto").click(process, g)
 | 
			
		||||
#         gr.Button("ONE image").click(*amg.set_one_image_mode(True))
 | 
			
		||||
#         gr.Button("MULTI image").click(*amg.set_one_image_mode(False))
 | 
			
		||||
 | 
			
		||||
# demo.launch()
 | 
			
		||||
@ -9,7 +9,6 @@ import os
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import warnings
 | 
			
		||||
from functools import lru_cache
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
@ -257,7 +256,6 @@ VIDEO_READER_BACKENDS = {
 | 
			
		||||
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@lru_cache(maxsize=1)
 | 
			
		||||
def get_video_reader_backend() -> str:
 | 
			
		||||
    if FORCE_QWENVL_VIDEO_READER is not None:
 | 
			
		||||
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
 | 
			
		||||
 | 
			
		||||
							
								
								
									
190  wgp.py
							@ -49,6 +49,7 @@ logging.set_verbosity_error
 | 
			
		||||
from preprocessing.matanyone  import app as matanyone_app
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
import requests
 | 
			
		||||
from shared.gradio.gallery import AdvancedMediaGallery
 | 
			
		||||
 | 
			
		||||
# import torch._dynamo as dynamo
 | 
			
		||||
# dynamo.config.recompile_limit = 2000   # default is 256
 | 
			
		||||
@ -58,9 +59,9 @@ global_queue_ref = []
 | 
			
		||||
AUTOSAVE_FILENAME = "queue.zip"
 | 
			
		||||
PROMPT_VARS_MAX = 10
 | 
			
		||||
 | 
			
		||||
target_mmgp_version = "3.5.10"
 | 
			
		||||
WanGP_version = "8.21"
 | 
			
		||||
settings_version = 2.27
 | 
			
		||||
target_mmgp_version = "3.5.11"
 | 
			
		||||
WanGP_version = "8.3"
 | 
			
		||||
settings_version = 2.28
 | 
			
		||||
max_source_video_frames = 3000
 | 
			
		||||
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
 | 
			
		||||
 | 
			
		||||
@ -466,7 +467,7 @@ def process_prompt_and_add_tasks(state, model_choice):
 | 
			
		||||
            image_mask = None
 | 
			
		||||
 | 
			
		||||
        if "G" in video_prompt_type:
 | 
			
		||||
            gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ")
 | 
			
		||||
            gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
 | 
			
		||||
        else: 
 | 
			
		||||
            denoising_strength = 1.0
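        # Example (illustrative numbers): with num_inference_steps = 30 and denoising_strength = 0.7,
        # denoising starts at step int(30 * (1 - 0.7)) = 9; with denoising_strength = 1.0 it starts at step 0,
        # i.e. the guide is fully re-denoised.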
 | 
			
		||||
        if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
 | 
			
		||||
@ -4733,46 +4734,61 @@ def generate_video(
 | 
			
		||||
                # special treatment of the start frame position when alignment to the first frame is requested, as otherwise the start frame number would be negative due to overlapped frames (previously compensated with padding later)
 | 
			
		||||
                audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            if video_guide is not None:
 | 
			
		||||
                keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
 | 
			
		||||
                if len(error) > 0:
 | 
			
		||||
                    raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
 | 
			
		||||
                keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
 | 
			
		||||
            if infinitetalk and video_guide is not None:
 | 
			
		||||
                src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True)
 | 
			
		||||
                new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
 | 
			
		||||
                src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
 | 
			
		||||
                refresh_preview["video_guide"] = src_image  
 | 
			
		||||
                src_video = convert_image_to_tensor(src_image).unsqueeze(1)
 | 
			
		||||
                if sample_fit_canvas != None:  
 | 
			
		||||
                    image_size  = src_video.shape[-2:]
 | 
			
		||||
                    sample_fit_canvas = None
 | 
			
		||||
            if ltxv and video_guide is not None:
 | 
			
		||||
                preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
 | 
			
		||||
                status_info = "Extracting " + processes_names[preprocess_type]
 | 
			
		||||
                send_cmd("progress", [0, get_latest_status(state, status_info)])
 | 
			
		||||
                # start one frame earlier to facilitate latents merging later
 | 
			
		||||
                src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
 | 
			
		||||
                if src_video !=  None:
 | 
			
		||||
                    src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
 | 
			
		||||
                    refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
 | 
			
		||||
                    src_video  = src_video.permute(3, 0, 1, 2)
 | 
			
		||||
                    src_video  = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
 | 
			
		||||
                    if sample_fit_canvas != None:
 | 
			
		||||
                        image_size = src_video.shape[-2:]
 | 
			
		||||
 | 
			
		||||
                if ltxv:
 | 
			
		||||
                    preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
 | 
			
		||||
                    status_info = "Extracting " + processes_names[preprocess_type]
 | 
			
		||||
                    send_cmd("progress", [0, get_latest_status(state, status_info)])
 | 
			
		||||
                    # start one frame earlier to facilitate latents merging later
 | 
			
		||||
                    src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps,  process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
 | 
			
		||||
                    if src_video !=  None:
 | 
			
		||||
                        src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
 | 
			
		||||
                        refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
 | 
			
		||||
                        src_video  = src_video.permute(3, 0, 1, 2)
 | 
			
		||||
                        src_video  = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
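                        # div_(127.5).sub_(1.) maps uint8 pixel values in [0, 255] to floats in [-1, 1] in place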
 | 
			
		||||
                        if sample_fit_canvas != None:
 | 
			
		||||
                            image_size = src_video.shape[-2:]
 | 
			
		||||
                            sample_fit_canvas = None
 | 
			
		||||
 | 
			
		||||
                elif hunyuan_custom_edit:
 | 
			
		||||
                    if "P" in  video_prompt_type:
 | 
			
		||||
                        progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
 | 
			
		||||
                    else:
 | 
			
		||||
                        progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
 | 
			
		||||
 | 
			
		||||
                    send_cmd("progress", progress_args)
 | 
			
		||||
                    src_video, src_mask = preprocess_video_with_mask(video_guide,  video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
 | 
			
		||||
                    refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) 
 | 
			
		||||
                    if src_mask != None:                        
 | 
			
		||||
                        refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
 | 
			
		||||
 | 
			
		||||
                elif "R" in video_prompt_type: # sparse video to video
 | 
			
		||||
                    src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
 | 
			
		||||
                    new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
 | 
			
		||||
                    src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
 | 
			
		||||
                    refresh_preview["video_guide"] = src_image  
 | 
			
		||||
                    src_video = convert_image_to_tensor(src_image).unsqueeze(1)
 | 
			
		||||
                    if sample_fit_canvas != None:  
 | 
			
		||||
                        image_size  = src_video.shape[-2:]
 | 
			
		||||
                        sample_fit_canvas = None
 | 
			
		||||
 | 
			
		||||
            if t2v and "G" in video_prompt_type:
 | 
			
		||||
                video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
 | 
			
		||||
                if video_guide_processed == None:
 | 
			
		||||
                    src_video = pre_video_guide
 | 
			
		||||
                else:
 | 
			
		||||
                    if sample_fit_canvas != None:
 | 
			
		||||
                        image_size = video_guide_processed.shape[-3: -1]
 | 
			
		||||
                        sample_fit_canvas = None
 | 
			
		||||
                    src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
 | 
			
		||||
                    if pre_video_guide != None:
 | 
			
		||||
                        src_video = torch.cat( [pre_video_guide, src_video], dim=1) 
 | 
			
		||||
                elif "G" in video_prompt_type: # video to video
 | 
			
		||||
                    video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
 | 
			
		||||
                    if video_guide_processed is None:
 | 
			
		||||
                        src_video = pre_video_guide
 | 
			
		||||
                    else:
 | 
			
		||||
                        if sample_fit_canvas != None:
 | 
			
		||||
                            image_size = video_guide_processed.shape[-3: -1]
 | 
			
		||||
                            sample_fit_canvas = None
 | 
			
		||||
                        src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
 | 
			
		||||
                        if pre_video_guide != None:
 | 
			
		||||
                            src_video = torch.cat( [pre_video_guide, src_video], dim=1) 
 | 
			
		||||
 | 
			
		||||
            if vace :
 | 
			
		||||
                image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
 | 
			
		||||
@ -4834,17 +4850,8 @@ def generate_video(
 | 
			
		||||
                if sample_fit_canvas != None:
 | 
			
		||||
                    image_size = src_video[0].shape[-2:]
 | 
			
		||||
                    sample_fit_canvas = None
 | 
			
		||||
            elif hunyuan_custom_edit:
 | 
			
		||||
                if "P" in  video_prompt_type:
 | 
			
		||||
                    progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
 | 
			
		||||
                else:
 | 
			
		||||
                    progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
 | 
			
		||||
 | 
			
		||||
                send_cmd("progress", progress_args)
 | 
			
		||||
                src_video, src_mask = preprocess_video_with_mask(video_guide,  video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
 | 
			
		||||
                refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) 
 | 
			
		||||
                if src_mask != None:                        
 | 
			
		||||
                    refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
 | 
			
		||||
 | 
			
		||||
            if len(refresh_preview) > 0:
 | 
			
		||||
                new_inputs= locals()
 | 
			
		||||
                new_inputs.update(refresh_preview)
 | 
			
		||||
@ -6013,10 +6020,16 @@ def video_to_source_video(state, input_file_list, choice):
 | 
			
		||||
def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
 | 
			
		||||
    file_list, file_settings_list = get_file_list(state, input_file_list)
 | 
			
		||||
    if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
 | 
			
		||||
    gr.Info(f"Selected Image was added to {target_name}")
 | 
			
		||||
    if target == None:
 | 
			
		||||
        target =[]
 | 
			
		||||
    target.append( file_list[choice])
 | 
			
		||||
    model_type = state["model_type"]
 | 
			
		||||
    model_def = get_model_def(model_type)    
 | 
			
		||||
    if model_def.get("one_image_ref_needed", False):
 | 
			
		||||
        gr.Info(f"Selected Image was set to {target_name}")
 | 
			
		||||
        target =[file_list[choice]]
 | 
			
		||||
    else:
 | 
			
		||||
        gr.Info(f"Selected Image was added to {target_name}")
 | 
			
		||||
        if target == None:
 | 
			
		||||
            target =[]
 | 
			
		||||
        target.append( file_list[choice])
 | 
			
		||||
    return target
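# Illustrative behaviour (hypothetical values): with file_list = [img0, img1] and choice = 1, a model
# flagged with "one_image_ref_needed" gets target = [img1] (the selection replaces any previous ref),
# while other models append, giving target = previous_refs + [img1].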
 | 
			
		||||
 | 
			
		||||
def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
 | 
			
		||||
@ -6229,6 +6242,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
 | 
			
		||||
        if not "WanGP" in configs.get("type", ""): configs = None 
 | 
			
		||||
    except:
 | 
			
		||||
        configs = None
 | 
			
		||||
    if configs is None: return None, False
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
    current_model_filename = state["model_filename"]
 | 
			
		||||
@ -6615,11 +6629,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
 | 
			
		||||
    return video_prompt_type,  gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
 | 
			
		||||
 | 
			
		||||
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
 | 
			
		||||
    video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI")
 | 
			
		||||
    video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI")
 | 
			
		||||
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
 | 
			
		||||
    control_video_visible = "V" in video_prompt_type
 | 
			
		||||
    ref_images_visible = "I" in video_prompt_type
 | 
			
		||||
    return video_prompt_type,  gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible )
 | 
			
		||||
    denoising_strength_visible = "G" in video_prompt_type
 | 
			
		||||
    return video_prompt_type,  gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible )
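# Illustrative example (assuming del_in_sequence / add_to_sequence simply remove and append the given
# letters): if video_prompt_type is "AUV" and the user picks the "GUV" entry, any of "RGUVQKI" is first
# stripped (leaving "A") and then "G", "U", "V" are added back, so the control video and denoising
# strength widgets become visible while the reference images gallery is hidden.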
 | 
			
		||||
 
 | 
			
		||||
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
 | 
			
		||||
#     video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
 | 
			
		||||
@ -6996,6 +7011,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
            v2i_switch_supported = (vace or t2v or standin) and not image_outputs
 | 
			
		||||
            ti2v_2_2 = base_model_type in ["ti2v_2_2"]
 | 
			
		||||
 | 
			
		||||
            def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
 | 
			
		||||
                with gr.Row(visible = visible) as gallery_row:
 | 
			
		||||
                    gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
 | 
			
		||||
                    gallery_amg.mount(update_form=update_form)
 | 
			
		||||
                return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements()
 | 
			
		||||
 | 
			
		||||
            image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
 | 
			
		||||
            if not v2i_switch_supported and not image_outputs:
 | 
			
		||||
                image_mode_value = 0
 | 
			
		||||
@ -7009,15 +7030,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
 | 
			
		||||
                    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column: 
 | 
			
		||||
                if vace or infinitetalk:
 | 
			
		||||
                    image_prompt_type_value= ui_defaults.get("image_prompt_type","")
 | 
			
		||||
                    image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
 | 
			
		||||
                    image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3)
 | 
			
		||||
 | 
			
		||||
                    image_start = gr.Gallery(visible = False)
 | 
			
		||||
                    image_end  = gr.Gallery(visible = False)
 | 
			
		||||
                    image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
 | 
			
		||||
                    image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
 | 
			
		||||
                    video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
 | 
			
		||||
                    model_mode = gr.Dropdown(visible = False)
 | 
			
		||||
                    keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) 
 | 
			
		||||
@ -7034,13 +7054,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                        image_prompt_type_choices += [("Continue Video", "V")]
 | 
			
		||||
                    image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3)
 | 
			
		||||
 | 
			
		||||
                    # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
 | 
			
		||||
                    image_start = gr.Gallery(preview= True,
 | 
			
		||||
                            label="Images as starting points for new videos", type ="pil", #file_types= "image", 
 | 
			
		||||
                            columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) 
 | 
			
		||||
                    image_end  = gr.Gallery(preview= True,
 | 
			
		||||
                            label="Images as ending points for new videos", type ="pil", #file_types= "image", 
 | 
			
		||||
                            columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
 | 
			
		||||
                    image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
 | 
			
		||||
                    image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) 
 | 
			
		||||
                    video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
 | 
			
		||||
                    if not diffusion_forcing:
 | 
			
		||||
                        model_mode = gr.Dropdown(
 | 
			
		||||
@ -7061,8 +7076,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                    keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) 
 | 
			
		||||
                elif recammaster:
 | 
			
		||||
                    image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V")
 | 
			
		||||
                    image_start = gr.Gallery(value = None, visible = False)
 | 
			
		||||
                    image_end  = gr.Gallery(value = None, visible= False)
 | 
			
		||||
                    image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
 | 
			
		||||
                    image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
 | 
			
		||||
                    video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),)
 | 
			
		||||
                    model_mode = gr.Dropdown(
 | 
			
		||||
                        choices=[
 | 
			
		||||
@ -7095,21 +7110,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                        image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3)
 | 
			
		||||
                        any_start_image = True
 | 
			
		||||
                        any_end_image = True
 | 
			
		||||
                        image_start = gr.Gallery(preview= True,
 | 
			
		||||
                                label="Images as starting points for new videos", type ="pil", #file_types= "image", 
 | 
			
		||||
                                columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) 
 | 
			
		||||
 | 
			
		||||
                        image_end  = gr.Gallery(preview= True,
 | 
			
		||||
                                label="Images as ending points for new videos", type ="pil", #file_types= "image", 
 | 
			
		||||
                                columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
 | 
			
		||||
                        image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
 | 
			
		||||
                        image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) 
 | 
			
		||||
                        if hunyuan_i2v:
 | 
			
		||||
                            video_source = gr.Video(value=None, visible=False)
 | 
			
		||||
                        else:
 | 
			
		||||
                            video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
 | 
			
		||||
                    else:
 | 
			
		||||
                        image_prompt_type = gr.Radio(choices=[("", "")], value="")
 | 
			
		||||
                        image_start = gr.Gallery(value=None)
 | 
			
		||||
                        image_end  = gr.Gallery(value=None)
 | 
			
		||||
                        image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
 | 
			
		||||
                        image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
 | 
			
		||||
                        video_source = gr.Video(value=None, visible=False)
 | 
			
		||||
                    model_mode = gr.Dropdown(value=None, visible=False)
 | 
			
		||||
                    keep_frames_video_source = gr.Text(visible=False)
 | 
			
		||||
@ -7184,12 +7194,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                    if infinitetalk:
 | 
			
		||||
                        video_prompt_type_video_guide_alt = gr.Dropdown(
 | 
			
		||||
                            choices=[
                                ("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Smooth Transitions", "UV"),
                                ("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"),
                                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
                                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
                                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
                                ("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
                                ("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
                                ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
                                ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
                            ],
 | 
			
		||||
                            value=filter_letters(video_prompt_type_value, "UVQKI"),
 | 
			
		||||
                            value=filter_letters(video_prompt_type_value, "RGUVQKI"),
 | 
			
		||||
                            label="Video to Video", scale = 3, visible= True, show_label= False,
 | 
			
		||||
                        ) 
 | 
			
		||||
                    else:
 | 
			
		||||
@ -7318,11 +7330,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
 | 
			
		||||
                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
 | 
			
		||||
                any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image
 | 
			
		||||
                image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""),
 | 
			
		||||
                        type ="pil",   show_label= True,
 | 
			
		||||
                        columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, 
 | 
			
		||||
                        value= ui_defaults.get("image_refs", None),
 | 
			
		||||
                 )
 | 
			
		||||
 | 
			
		||||
                image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
 | 
			
		||||
                image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images")  + (" (each Image will start a new Clip)" if infinitetalk else "")
 | 
			
		||||
                image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
 | 
			
		||||
 | 
			
		||||
                frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) 
 | 
			
		||||
                image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Internally Rescale Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
 | 
			
		||||
@ -7935,7 +7946,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
                                      video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, 
 | 
			
		||||
                                      video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
 | 
			
		||||
                                      NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, 
 | 
			
		||||
                                      min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] #  presets_column,
 | 
			
		||||
                                      min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra #  presets_column,
 | 
			
		||||
        if update_form:
 | 
			
		||||
            locals_dict = locals()
 | 
			
		||||
            gen_inputs = [state_dict if k=="state" else locals_dict[k]  for k in inputs_names] + [state_dict] + extra_inputs
 | 
			
		||||
@ -7953,11 +7964,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
 | 
			
		||||
            guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
 | 
			
		||||
            audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
 | 
			
		||||
            audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
 | 
			
		||||
            image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) 
 | 
			
		||||
            image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] ) 
 | 
			
		||||
            # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
 | 
			
		||||
            video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref,  image_refs_relative_size, frames_positions,video_guide_outpainting_col])
 | 
			
		||||
            video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref,  image_refs_relative_size, frames_positions,video_guide_outpainting_col])
 | 
			
		||||
            video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
 | 
			
		||||
            video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ])
 | 
			
		||||
            video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ])
 | 
			
		||||
            video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
 | 
			
		||||
            video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
 | 
			
		||||
            multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
 | 
			
		||||
@ -8348,6 +8359,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
 | 
			
		||||
                        ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
 | 
			
		||||
                        ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
 | 
			
		||||
                        ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
 | 
			
		||||
                        ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
 | 
			
		||||
                    ],
 | 
			
		||||
                    value= attention_mode,
 | 
			
		||||
                    label="Attention Type",
 | 
			
		||||
@ -8663,7 +8675,7 @@ def generate_about_tab():
 | 
			
		||||
    gr.Markdown("- <B>Blackforest Labs</B> for the innovative Flux image generators (https://github.com/black-forest-labs/flux)")
 | 
			
		||||
    gr.Markdown("- <B>Alibaba Qwen Team</B> for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)")
 | 
			
		||||
    gr.Markdown("- <B>Lightricks</B> for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)")
 | 
			
		||||
    gr.Markdown("- <B>Hugging Face</B> for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
 | 
			
		||||
    gr.Markdown("- <B>Hugging Face</B> for providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
 | 
			
		||||
    gr.Markdown("<BR>Huge acknowledgments to these great open source projects used in WanGP:")
 | 
			
		||||
    gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
 | 
			
		||||
    gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
 | 
			
		||||
@ -8672,7 +8684,7 @@ def generate_about_tab():
 | 
			
		||||
    gr.Markdown("- <B>Pyannote</B>: speaker diarization (https://github.com/pyannote/pyannote-audio)")
 | 
			
		||||
 | 
			
		||||
    gr.Markdown("<BR>Special thanks to the following people for their support:")
 | 
			
		||||
    gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
 | 
			
		||||
    gr.Markdown("- <B>Cocktail Peanuts</B> : QA dpand simple installation via Pinokio.computer")
 | 
			
		||||
    gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
 | 
			
		||||
    gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
 | 
			
		||||
    gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
 | 
			
		||||
 | 
			
		||||