from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens  # note: upstream module file is spelled "attenion"
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, modulate_, apply_gate, apply_gate_and_accumulate_
from .token_refiner import SingleTokenRefiner
import numpy as np
from mmgp import offload
from wan.modules.attention import pay_attention


def get_linear_split_map():
    # Describes how the fused linear layers of this model can be split into
    # separate projections (q/k/v, plus the MLP input of the single-stream
    # blocks); the map is consumed by the mmgp offload utilities.
    hidden_size = 3072
    split_linear_modules_map = {
        "img_attn_qkv": {
            "mapped_modules": ["img_attn_q", "img_attn_k", "img_attn_v"],
            "split_sizes": [hidden_size, hidden_size, hidden_size],
        },
        "linear1": {
            "mapped_modules": ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"],
            # linear1 packs q, k, v (3 * hidden_size) plus the MLP input (4 * hidden_size)
            "split_sizes": [hidden_size, hidden_size, hidden_size, 7 * hidden_size - 3 * hidden_size],
        },
    }
    return split_linear_modules_map
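
# A minimal sketch (hypothetical, for illustration only; the real splitting is
# performed by the mmgp offload helpers) of how such a map applies to a fused weight:
#
#   split_map = get_linear_split_map()
#   sizes = split_map["img_attn_qkv"]["split_sizes"]           # [3072, 3072, 3072]
#   q_w, k_w, v_w = torch.split(fused_qkv_weight, sizes, dim=0)
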
try:
    from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
except ImportError:
    BlockDiagonalPaddedKeysMask = None


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video; see more details (SD3): https://arxiv.org/abs/2403.03206
                                         (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.attention_mode = attention_mode
        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.img_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.img_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.txt_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.txt_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        attn_mask=None,
        seqlens_q: Optional[torch.Tensor] = None,
        seqlens_kv: Optional[torch.Tensor] = None,
        freqs_cis: tuple = None,
        condition_type: str = None,
        token_replace_vec: torch.Tensor = None,
        first_frame_token_num: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if condition_type == "token_replace":
            img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type,
                                                            token_replace_vec=token_replace_vec)
            (img_mod1_shift,
             img_mod1_scale,
             img_mod1_gate,
             img_mod2_shift,
             img_mod2_scale,
             img_mod2_gate) = img_mod1.chunk(6, dim=-1)
            (tr_img_mod1_shift,
             tr_img_mod1_scale,
             tr_img_mod1_gate,
             tr_img_mod2_shift,
             tr_img_mod2_scale,
             tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
        else:
            (
                img_mod1_shift,
                img_mod1_scale,
                img_mod1_gate,
                img_mod2_shift,
                img_mod2_scale,
                img_mod2_gate,
            ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)
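
        # Each stream's ModulateDiT (factor=6) produces shift/scale/gate for the
        # attention sub-layer (mod1) and for the MLP sub-layer (mod2). A minimal
        # sketch of the adaLN modulation these chunks feed (see modulate_ in
        # modulate_layers for the in-place version used here):
        #
        #   x_mod = x * (1 + scale) + shift        # applied after LayerNorm
        #   out   = x + gate * sublayer(x_mod)     # gated residual update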

        ##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep !
        # I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = img_modulated.to(torch.bfloat16)

        if condition_type == "token_replace":
            modulate_(img_modulated[:, :first_frame_token_num], shift=tr_img_mod1_shift, scale=tr_img_mod1_scale)
            modulate_(img_modulated[:, first_frame_token_num:], shift=img_mod1_shift, scale=img_mod1_scale)
        else:
            modulate_(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)

        # img_attn_q/k/v are created at load time by splitting img_attn_qkv
        # (see get_linear_split_map above).
        shape = (*img_modulated.shape[:2], self.heads_num, img_modulated.shape[-1] // self.heads_num)
        img_q = self.img_attn_q(img_modulated).view(*shape)
        img_k = self.img_attn_k(img_modulated).view(*shape)
        img_v = self.img_attn_v(img_modulated).view(*shape)
        del img_modulated

        # Apply QK-Norm if needed (in place).
        self.img_attn_q_norm.apply_(img_q)
        img_q_len = img_q.shape[1]
        self.img_attn_k_norm.apply_(img_k)
        img_kv_len = img_k.shape[1]
        batch_size = img_k.shape[0]
        # Apply RoPE if needed.
        qklist = [img_q, img_k]
        del img_q, img_k
        img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
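
        # apply_rotary_emb rotates each (even, odd) channel pair of q and k by a
        # position-dependent angle; schematically, for one pair (x1, x2) at angle t:
        #
        #   x1' = x1 * cos(t) - x2 * sin(t)
        #   x2' = x1 * sin(t) + x2 * cos(t)
        #
        # freqs_cis carries the precomputed (cos, sin) tables for the t/h/w axes.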
        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        modulate_(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)

        txt_qkv = self.txt_attn_qkv(txt_modulated)
        del txt_modulated
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        del txt_qkv
        # Apply QK-Norm if needed (in place).
        self.txt_attn_q_norm.apply_(txt_q)
        self.txt_attn_k_norm.apply_(txt_k)

        # Run actual attention.
        q = torch.cat((img_q, txt_q), dim=1)
        del img_q, txt_q
        k = torch.cat((img_k, txt_k), dim=1)
        del img_k, txt_k
        v = torch.cat((img_v, txt_v), dim=1)
        del img_v, txt_v

        # attention computation start
        qkv_list = [q, k, v]
        del q, k, v

        attn = pay_attention(
            qkv_list,
            attention_mask=attn_mask,
            q_lens=seqlens_q,
            k_lens=seqlens_kv,
        )
        b, s, a, d = attn.shape
        attn = attn.reshape(b, s, -1)
        del qkv_list

        # attention computation end

        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        del attn
        # Calculate the img blocks.

        if condition_type == "token_replace":
            img_attn = self.img_attn_proj(img_attn)
            apply_gate_and_accumulate_(img[:, :first_frame_token_num], img_attn[:, :first_frame_token_num], gate=tr_img_mod1_gate)
            apply_gate_and_accumulate_(img[:, first_frame_token_num:], img_attn[:, first_frame_token_num:], gate=img_mod1_gate)
            del img_attn
            img_modulated = self.img_norm2(img)
            img_modulated = img_modulated.to(torch.bfloat16)
            modulate_(img_modulated[:, :first_frame_token_num], shift=tr_img_mod2_shift, scale=tr_img_mod2_scale)
            modulate_(img_modulated[:, first_frame_token_num:], shift=img_mod2_shift, scale=img_mod2_scale)
            self.img_mlp.apply_(img_modulated)
            apply_gate_and_accumulate_(img[:, :first_frame_token_num], img_modulated[:, :first_frame_token_num], gate=tr_img_mod2_gate)
            apply_gate_and_accumulate_(img[:, first_frame_token_num:], img_modulated[:, first_frame_token_num:], gate=img_mod2_gate)
            del img_modulated
        else:
            img_attn = self.img_attn_proj(img_attn)
            apply_gate_and_accumulate_(img, img_attn, gate=img_mod1_gate)
            del img_attn
            img_modulated = self.img_norm2(img)
            img_modulated = img_modulated.to(torch.bfloat16)
            modulate_(img_modulated, shift=img_mod2_shift, scale=img_mod2_scale)
            self.img_mlp.apply_(img_modulated)
            apply_gate_and_accumulate_(img, img_modulated, gate=img_mod2_gate)
            del img_modulated

        # Calculate the txt blocks.
        txt_attn = self.txt_attn_proj(txt_attn)
        apply_gate_and_accumulate_(txt, txt_attn, gate=txt_mod1_gate)
        del txt_attn
        txt_modulated = self.txt_norm2(txt)
        txt_modulated = txt_modulated.to(torch.bfloat16)
        modulate_(txt_modulated, shift=txt_mod2_shift, scale=txt_mod2_scale)
        txt_mlp = self.txt_mlp(txt_modulated)
        del txt_modulated
        apply_gate_and_accumulate_(txt, txt_mlp, gate=txt_mod2_gate)
        return img, txt


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
                  (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attention_mode = attention_mode
        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim ** -0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(
            hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
        )
        # proj and mlp_out
        self.linear2 = nn.Linear(
            hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
        )
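
        # Parallel-block formulation (ViT-22B style): one fused input projection
        # feeds both attention and the MLP, and one fused output projection
        # merges them. Schematically:
        #
        #   qkv, mlp_in = split(linear1(x))
        #   out = x + gate * linear2(concat(attention(qkv), act(mlp_in)))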

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        attn_mask=None,
        seqlens_q: Optional[torch.Tensor] = None,
        seqlens_kv: Optional[torch.Tensor] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        condition_type: str = None,
        token_replace_vec: torch.Tensor = None,
        first_frame_token_num: int = None,
    ) -> torch.Tensor:

        ##### More spaghetti VRAM optimizations done by DeepBeepMeep !
        # I am sure you are a nice person and as you copy this code, you will give me proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter

        if condition_type == "token_replace":
            mod, tr_mod = self.modulation(vec,
                                          condition_type=condition_type,
                                          token_replace_vec=token_replace_vec)
            (mod_shift,
             mod_scale,
             mod_gate) = mod.chunk(3, dim=-1)
            (tr_mod_shift,
             tr_mod_scale,
             tr_mod_gate) = tr_mod.chunk(3, dim=-1)
        else:
            mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)

        img_mod = self.pre_norm(img)
        img_mod = img_mod.to(torch.bfloat16)
        if condition_type == "token_replace":
            modulate_(img_mod[:, :first_frame_token_num], shift=tr_mod_shift, scale=tr_mod_scale)
            modulate_(img_mod[:, first_frame_token_num:], shift=mod_shift, scale=mod_scale)
        else:
            modulate_(img_mod, shift=mod_shift, scale=mod_scale)
        txt_mod = self.pre_norm(txt)
        txt_mod = txt_mod.to(torch.bfloat16)
        modulate_(txt_mod, shift=mod_shift, scale=mod_scale)

        # linear1_attn_q/k/v and linear1_mlp are created at load time by
        # splitting linear1 (see get_linear_split_map above).
        shape = (*img_mod.shape[:2], self.heads_num, img_mod.shape[-1] // self.heads_num)
        img_q = self.linear1_attn_q(img_mod).view(*shape)
        img_k = self.linear1_attn_k(img_mod).view(*shape)
        img_v = self.linear1_attn_v(img_mod).view(*shape)

        shape = (*txt_mod.shape[:2], self.heads_num, txt_mod.shape[-1] // self.heads_num)
        txt_q = self.linear1_attn_q(txt_mod).view(*shape)
        txt_k = self.linear1_attn_k(txt_mod).view(*shape)
        txt_v = self.linear1_attn_v(txt_mod).view(*shape)

        batch_size = img_mod.shape[0]

        # Apply QK-Norm if needed (in place).
        self.q_norm.apply_(img_q)
        self.k_norm.apply_(img_k)
        self.q_norm.apply_(txt_q)
        self.k_norm.apply_(txt_k)

        qklist = [img_q, img_k]
        del img_q, img_k
        img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
        img_q_len = img_q.shape[1]
        q = torch.cat((img_q, txt_q), dim=1)
        del img_q, txt_q
        k = torch.cat((img_k, txt_k), dim=1)
        img_kv_len = img_k.shape[1]
        del img_k, txt_k

        v = torch.cat((img_v, txt_v), dim=1)
        del img_v, txt_v

        # attention computation start
        qkv_list = [q, k, v]
        del q, k, v
        attn = pay_attention(
            qkv_list,
            attention_mask=attn_mask,
            q_lens=seqlens_q,
            k_lens=seqlens_kv,
        )
        b, s, a, d = attn.shape
        attn = attn.reshape(b, s, -1)
        del qkv_list
        # attention computation end

        x_mod = torch.cat((img_mod, txt_mod), 1)
        del img_mod, txt_mod
        x_mod_shape = x_mod.shape
        x_mod = x_mod.view(-1, x_mod.shape[-1])
        # Run the MLP branch and the output projection on ~1/6 of the tokens at
        # a time, writing the result back in place to cap peak VRAM usage.
        chunk_size = int(x_mod_shape[1] / 6)
        x_chunks = torch.split(x_mod, chunk_size)
        attn = attn.view(-1, attn.shape[-1])
        attn_chunks = torch.split(attn, chunk_size)
        for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
            mlp_chunk = self.linear1_mlp(x_chunk)
            mlp_chunk = self.mlp_act(mlp_chunk)
            attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
            del attn_chunk, mlp_chunk
            x_chunk[...] = self.linear2(attn_mlp_chunk)
            del attn_mlp_chunk
        x_mod = x_mod.view(x_mod_shape)
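
        # The loop above is just a chunked, in-place version of:
        #
        #   mlp = self.mlp_act(self.linear1_mlp(x_mod))
        #   x_mod = self.linear2(torch.cat((attn, mlp), dim=-1))
        #
        # computed piecewise so the 4x-wide MLP activation never exists in full.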

        if condition_type == "token_replace":
            apply_gate_and_accumulate_(img[:, :first_frame_token_num, :], x_mod[:, :first_frame_token_num, :], gate=tr_mod_gate)
            apply_gate_and_accumulate_(img[:, first_frame_token_num:, :], x_mod[:, first_frame_token_num:-txt_len, :], gate=mod_gate)
        else:
            apply_gate_and_accumulate_(img, x_mod[:, :-txt_len, :], gate=mod_gate)

        apply_gate_and_accumulate_(txt, x_mod[:, -txt_len:, :], gate=mod_gate)

        return img, txt


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    def preprocess_loras(self, model_filename, sd):
        if "i2v" not in model_filename:
            return sd

        repl_list = ["double_blocks", "single_blocks", "final_layer", "img_mlp", "img_attn_qkv", "img_attn_proj", "img_mod", "txt_mlp", "txt_attn_qkv", "txt_attn_proj", "txt_mod", "linear1",
                     "linear2", "modulation", "mlp_fc1"]
        src_list = [name + "_" for name in repl_list] + ["_" + name for name in repl_list]
        tgt_list = [name + "." for name in repl_list] + ["." + name for name in repl_list]

        new_sd = {}
        for k, v in sd.items():
            if k.startswith("Hunyuan_video_I2V_lora_"):
                # crude conversion for a non-reversible LoRA naming scheme
                k = k.replace("Hunyuan_video_I2V_lora_", "diffusion_model.")
                k = k.replace("lora_up", "lora_B")
                k = k.replace("lora_down", "lora_A")
                for s, t in zip(src_list, tgt_list):
                    k = k.replace(s, t)
                if "individual_token_refiner" in k:
                    k = k.replace("txt_in_individual_token_refiner_blocks_", "txt_in.individual_token_refiner.blocks.")
                    k = k.replace("_mlp_fc", ".mlp.fc")
                    k = k.replace(".mlp_fc", ".mlp.fc")
            new_sd[k] = v
        return new_sd
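
    # Illustrative example (hypothetical key, for documentation only) of the
    # renaming performed above:
    #
    #   "Hunyuan_video_I2V_lora_double_blocks_0_img_attn_qkv.lora_up.weight"
    #     -> "diffusion_model.double_blocks.0.img_attn_qkv.lora_B.weight"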

    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
    """

    @register_to_config
    def __init__(
        self,
        i2v_condition_type,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: Optional[str] = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list
        self.i2v_condition_type = i2v_condition_type
        self.attention_mode = attention_mode

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = 4096
        self.text_states_dim_2 = 768

        if hidden_size % heads_num != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
            )
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(
                f"Got {rope_dim_list} but expected positional dim {pe_dim}"
            )
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        # image projection
        self.img_in = PatchEmbed(
            self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
        )

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(
                self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
            )
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        # time modulation
        self.time_in = TimestepEmbedder(
            self.hidden_size, get_activation_layer("silu"), **factory_kwargs
        )

        # text modulation
        self.vector_in = MLPEmbedder(
            self.text_states_dim_2, self.hidden_size, **factory_kwargs
        )

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(
                self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
            if guidance_embed
            else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

    def lock_layers_dtypes(self, dtype=torch.float32):
        # Keep the final projection (and its modulation) in a fixed dtype,
        # typically float32, for output numerical stability.
        layer_list = [self.final_layer, self.final_layer.linear, self.final_layer.adaLN_modulation[1]]
        target_dtype = dtype

        for current_layer_list, current_dtype in zip([layer_list], [target_dtype]):
            for layer in current_layer_list:
                layer._lock_dtype = dtype

                if hasattr(layer, "weight") and layer.weight.dtype != current_dtype:
                    layer.weight.data = layer.weight.data.to(current_dtype)
                    if getattr(layer, "bias", None) is not None:
                        layer.bias.data = layer.bias.data.to(current_dtype)

        self._lock_dtype = dtype

    def enable_deterministic(self):
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        ref_latents: torch.Tensor = None,
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        pipeline=None,
        x_id=0,
        callback=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:

        img = x
        batch_no, _, ot, oh, ow = x.shape
        del x
        txt = text_states
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors.
        vec = self.time_in(t)

        if self.i2v_condition_type == "token_replace":
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.time_in(token_replace_t)
            first_frame_token_num = th * tw
        else:
            token_replace_vec = None
            first_frame_token_num = None

        # text modulation
        vec_2 = self.vector_in(text_states_2)
        del text_states_2
        vec += vec_2
        if self.i2v_condition_type == "token_replace":
            token_replace_vec += vec_2
        del vec_2

        # guidance modulation
        if self.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )

            # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
            vec = vec + self.guidance_in(guidance)

        # Embed image and text.
        img = self.img_in(img)
        if ref_latents is not None:
            ref_latents = self.img_in(ref_latents)
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        if ref_latents is None:
            ref_length = None
        else:
            ref_length = ref_latents.shape[-2]
            img = torch.cat([ref_latents, img], dim=-2)  # prepend the reference tokens
        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        text_len = text_mask.sum(1)
        total_len = text_len + img_seq_len
        seqlens_q = seqlens_kv = total_len
        attn_mask = None

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

        if self.enable_teacache:
            if x_id == 0:
                self.should_calc = True
                inp = img[0:1]
                vec_ = vec
                (img_mod1_shift, img_mod1_scale, _, _, _, _) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
                normed_inp = self.double_blocks[0].img_norm1(inp)
                normed_inp = normed_inp.to(torch.bfloat16)
                modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
                del normed_inp, img_mod1_shift, img_mod1_scale
                if self.teacache_counter <= self.teacache_start_step or self.teacache_counter == self.num_steps - 1:
                    self.accumulated_rel_l1_distance = 0
                else:
                    coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
                    rescale_func = np.poly1d(coefficients)
                    self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
                    if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                        self.should_calc = False
                        self.teacache_skipped_steps += 1
                    else:
                        self.accumulated_rel_l1_distance = 0
                self.previous_modulated_input = modulated_inp
                self.teacache_counter += 1
                if self.teacache_counter == self.num_steps:
                    self.teacache_counter = 0
        else:
            self.should_calc = True
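
        # TeaCache skip rule, schematically: e_t is the modulated input of the
        # first double block; a step is skipped while the accumulated rescaled
        # relative L1 change stays under the threshold:
        #
        #   rel_l1  = mean(|e_t - e_{t-1}|) / mean(|e_{t-1}|)
        #   acc    += poly(rel_l1)                 # 4th-degree rescaling polynomial
        #   skip    = acc < self.rel_l1_thresh     # then reuse the cached residual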

        if not self.should_calc:
            # Skipped step: reuse the cached residual from the last full pass.
            img += self.previous_residual[x_id]
        else:
            if self.enable_teacache:
                self.previous_residual[x_id] = None
                ori_img = img[0:1].clone()
            # --------------------- Pass through DiT blocks ------------------------
            for _, block in enumerate(self.double_blocks):
                for i in range(len(img)):
                    if callback is not None:
                        callback(-1, None, False, True)
                    if pipeline._interrupt:
                        return None
                    double_block_args = [
                        img[i:i+1],
                        txt[i:i+1],
                        vec[i:i+1],
                        attn_mask,
                        seqlens_q[i:i+1],
                        seqlens_kv[i:i+1],
                        freqs_cis,
                        self.i2v_condition_type,
                        token_replace_vec,
                        first_frame_token_num,
                    ]

                    img[i], txt[i] = block(*double_block_args)
                    double_block_args = None

            for _, block in enumerate(self.single_blocks):
                for i in range(len(img)):
                    if callback is not None:
                        callback(-1, None, False, True)
                    if pipeline._interrupt:
                        return None
                    single_block_args = [
                        img[i:i+1],
                        txt[i:i+1],
                        vec[i:i+1],
                        txt_seq_len,
                        attn_mask,
                        seqlens_q[i:i+1],
                        seqlens_kv[i:i+1],
                        (freqs_cos, freqs_sin),
                        self.i2v_condition_type,
                        token_replace_vec,
                        first_frame_token_num,
                    ]

                    img[i], txt[i] = block(*single_block_args)
                    single_block_args = None

            if self.enable_teacache:
                if len(img) > 1:
                    self.previous_residual[0] = torch.empty_like(img)
                    for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
                        if i < len(img) - 1:
                            residual[...] = torch.sub(x, ori_img)
                        else:
                            residual[...] = ori_img
                            torch.sub(x, ori_img, out=residual)
                    x = None
                else:
                    self.previous_residual[x_id] = ori_img
                    torch.sub(img, ori_img, out=self.previous_residual[x_id])

        if ref_length is not None:
            img = img[:, ref_length:]
        # ---------------------------- Final layer ------------------------------
        out_dtype = self.final_layer.linear.weight.dtype
        vec = vec.to(out_dtype)
        img_list = []
        for img_chunk, vec_chunk in zip(img, vec):
            img_list.append(self.final_layer(img_chunk.to(out_dtype).unsqueeze(0), vec_chunk.unsqueeze(0)))  # (N, T, patch_size ** 2 * out_channels)
        img = torch.cat(img_list)
        img_list = None

        img = self.unpatchify(img, tt, th, tw)

        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T*H*W, pt*ph*pw * C)
        imgs: (N, C, T*pt, H*ph, W*pw)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs
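
    # Shape sketch: with patch_size = [1, 2, 2] and out_channels = C, an input of
    # (N, t*h*w, 1*2*2*C) unpatchifies to a latent video of (N, C, t, 2*h, 2*w).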

    def params_count(self):
        counts = {
            "double": sum(
                [
                    sum(p.numel() for p in block.img_attn_qkv.parameters())
                    + sum(p.numel() for p in block.img_attn_proj.parameters())
                    + sum(p.numel() for p in block.img_mlp.parameters())
                    + sum(p.numel() for p in block.txt_attn_qkv.parameters())
                    + sum(p.numel() for p in block.txt_attn_proj.parameters())
                    + sum(p.numel() for p in block.txt_mlp.parameters())
                    for block in self.double_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["double"] + counts["single"]
        return counts


#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

HUNYUAN_VIDEO_CONFIG = {
    "HYVideo-T/2": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-cfgdistill": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "guidance_embed": True,
    },
    "HYVideo-S/2": {
        "mm_double_blocks_depth": 6,
        "mm_single_blocks_depth": 12,
        "rope_dim_list": [12, 42, 42],
        "hidden_size": 480,
        "heads_num": 5,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-custom": {  # 9.0B / 12.5B
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
}
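
# A minimal sketch (hypothetical usage; the real model construction lives in the
# loading code of this repo) of instantiating the backbone from a config entry.
# The in_channels value below is an assumption for illustration:
#
#   cfg = HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"]
#   model = HYVideoDiffusionTransformer(
#       i2v_condition_type=None,
#       in_channels=16,           # assumed latent channel count
#       dtype=torch.bfloat16,
#       **cfg,
#   )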