# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

from einops import rearrange
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
import numpy as np
from typing import Union, Optional

from mmgp import offload
from .attention import pay_attention
from torch.backends.cuda import sdp_kernel

__all__ = ['WanModel']

def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float32)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x

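# Illustrative usage (not part of the original module): for a batch of integer
# timesteps, the embedding returns one row of length `dim` per position, with
# the cosine half first and the sine half second. A quick shape check:
#
#     t = torch.tensor([0, 500, 999])
#     emb = sinusoidal_embedding_1d(256, t)   # -> torch.Size([3, 256])
#
# WanModel.forward feeds the result of this call into self.time_embedding.
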
def reshape_latent(latent, latent_frames):
    if latent_frames == latent.shape[0]:
        return latent
    return latent.reshape(latent_frames, -1, latent.shape[-1])

def identify_k(b: float, d: int, N: int):
    """
    This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.

    Args:
        b (`float`): The base frequency for RoPE.
        d (`int`): Dimension of the frequency tensor
        N (`int`): The first observed repetition frame in latent space
    Returns:
        k (`int`): The index of the intrinsic frequency component
        N_k (`int`): The period of the intrinsic frequency component in latent space
    Example:
        In HunyuanVideo, b=256 and d=16, and the repetition occurs at approximately 8s (N=48 in latent space).
        k, N_k = identify_k(b=256, d=16, N=48)
        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
    """

    # Compute the period of each frequency in RoPE according to Eq.(4)
    periods = []
    for j in range(1, d // 2 + 1):
        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
        N_j = round(2 * torch.pi / theta_j)
        periods.append(N_j)

    # Identify the intrinsic frequency whose period is closest to N (see Eq.(7))
    diffs = [abs(N_j - N) for N_j in periods]
    k = diffs.index(min(diffs)) + 1
    N_k = periods[k-1]
    return k, N_k

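# Illustrative check (not part of the original module), reproducing the example
# from the docstring above:
#
#     k, N_k = identify_k(b=256, d=16, N=48)
#     # k == 4, N_k == 50 for HunyuanVideo-style RoPE, as described above
#
# k is then passed to rope_params_riflex below as the index of the frequency to clamp.
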
def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim)
    inv_theta_pow = 1.0 / torch.pow(theta, exponents)

    # RiFLEx: set the k-th (intrinsic) frequency so that it completes just under
    # one full period over the L_test frames generated at inference time
    inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test

    freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
    if True:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return (freqs_cos, freqs_sin)
    else:
        freqs = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
    return freqs

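# Illustrative usage (not part of the original module): build a RiFLEx-adjusted
# frequency table for a hypothetical per-head dimension of 44 and up to 1024
# positions, clamping the k-th intrinsic frequency as described above:
#
#     freqs_cos, freqs_sin = rope_params_riflex(1024, dim=44, L_test=30, k=6)
#     # freqs_cos.shape == freqs_sin.shape == (1024, 44)
#
# WanSelfAttention.forward below passes such a freqs object to apply_rotary_emb
# with head_first=False.
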
def relative_l1_distance(last_tensor, current_tensor):
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32)

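# Illustrative note (not part of the original module): this is the TeaCache
# change metric. WanModel.forward computes the same quantity inline on
# consecutive time embeddings, accumulates a rescaled version of it, and skips
# the transformer blocks while the accumulated value stays below
# self.rel_l1_thresh, e.g.
#
#     prev, cur = torch.randn(1, 16), torch.randn(1, 16)
#     score = relative_l1_distance(prev, cur)   # scalar float32 tensor
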
class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        y = x.float()
        y.pow_(2)
        y = y.mean(dim=-1, keepdim=True)
        y += self.eps
        y.rsqrt_()
        x *= y
        x *= self.weight
        return x
        # return self._norm(x).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

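# Illustrative usage (not part of the original module): the forward above
# normalizes in place (no extra full-size activation copy) to keep VRAM low,
# so the input tensor itself is modified and returned:
#
#     norm = WanRMSNorm(64)
#     h = torch.randn(2, 10, 64)
#     out = norm(h)   # `out` shares storage with `h`
#
# self.norm_q / self.norm_k in WanSelfAttention rely on this by calling the
# norm purely for its side effect on q and k.
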
def my_LayerNorm(norm, x):
    y = x.float()
    y_m = y.mean(dim=-1, keepdim=True)
    y -= y_m
    del y_m
    y.pow_(2)
    y = y.mean(dim=-1, keepdim=True)
    y += norm.eps
    y.rsqrt_()
    x = x * y
    return x

class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # return F.layer_norm(
        #     input, self.normalized_shape, self.weight, self.bias, self.eps
        # )
        y = super().forward(x)
        x = y.type_as(x)
        return x
        # return super().forward(x).type_as(x)


from wan.modules.posemb_layers import apply_rotary_emb

class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, xlist, grid_sizes, freqs, block_mask=None):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        x = xlist[0]
        xlist.clear()

        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        q = self.q(x)
        self.norm_q(q)
        q = q.view(b, s, n, d)  # !!!
        k = self.k(x)
        self.norm_k(k)
        k = k.view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        del x
        qklist = [q, k]
        del q, k

        q, k = apply_rotary_emb(qklist, freqs, head_first=False)
        qkv_list = [q, k, v]
        del q, k, v
        if block_mask is None:
            x = pay_attention(
                qkv_list,
                window_size=self.window_size)
        else:
            # masked path: rebind q, k, v from qkv_list since the names were deleted above
            q, k, v = qkv_list
            with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                x = (
                    torch.nn.functional.scaled_dot_product_attention(
                        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
                    )
                    .transpose(1, 2)
                    .contiguous()
                )

        # if not self._flag_ar_attention:
        #     q = rope_apply(q, grid_sizes, freqs)
        #     k = rope_apply(k, grid_sizes, freqs)
        #     x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
        # else:
        #     q = rope_apply(q, grid_sizes, freqs)
        #     k = rope_apply(k, grid_sizes, freqs)
        #     q = q.to(torch.bfloat16)
        #     k = k.to(torch.bfloat16)
        #     v = v.to(torch.bfloat16)
        #
        #     with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        #         x = (
        #             torch.nn.functional.scaled_dot_product_attention(
        #                 q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
        #             )
        #             .transpose(1, 2)
        #             .contiguous()
        #         )

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

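# Illustrative shape walk-through (not part of the original module), assuming
# dim=2048 and num_heads=16 (head_dim=128) and a hypothetical token count:
#
#     attn = WanSelfAttention(dim=2048, num_heads=16)
#     x = torch.randn(1, 75600, 2048)          # one latent video, flattened to tokens
#     # out = attn([x], grid_sizes, freqs)     # q/k/v are viewed as [1, 75600, 16, 128],
#                                              # RoPE is applied, then pay_attention runs
#                                              # and the result is projected back to [1, 75600, 2048]
#
# grid_sizes and freqs come from WanModel.forward; the call consumes the list
# in place (xlist is cleared) to release the activation early.
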
class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, xlist, context):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
        """
        x = xlist[0]
        xlist.clear()
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x)
        del x
        self.norm_q(q)
        q = q.view(b, -1, n, d)
        k = self.k(context)
        self.norm_k(k)
        k = k.view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        qkv_list = [q, k, v]
        del q, k, v
        x = pay_attention(qkv_list, cross_attn=True)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, xlist, context):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
        """

        ##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep !
        # I am sure you are a nice person and, as you copy this code, you will give me proper credit:
        # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter

        x = xlist[0]
        xlist.clear()

        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x)
        del x
        self.norm_q(q)
        q = q.view(b, -1, n, d)
        k = self.k(context)
        self.norm_k(k)
        k = k.view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # attention over the text tokens
        qkv_list = [q, k, v]
        del k, v
        x = pay_attention(qkv_list)

        # attention over the image tokens
        k_img = self.k_img(context_img)
        self.norm_k_img(k_img)
        k_img = k_img.view(b, -1, n, d)
        v_img = self.v_img(context_img).view(b, -1, n, d)
        qkv_list = [q, k_img, v_img]
        del q, k_img, v_img
        img_x = pay_attention(qkv_list)

        # output
        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x += img_x
        del img_x
        x = self.o(x)
        return x

WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}

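# Illustrative note (not part of the original module): WanAttentionBlock looks
# up its cross-attention flavour in this registry, e.g.
#
#     cls = WAN_CROSSATTENTION_CLASSES['i2v_cross_attn']   # -> WanI2VCrossAttention
#     cross_attn = cls(dim=2048, num_heads=16, window_size=(-1, -1), qk_norm=True, eps=1e-6)
#
# 't2v_cross_attn' attends to text tokens only; 'i2v_cross_attn' additionally
# attends to the first 257 CLIP image tokens prepended to the context.
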
class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 block_id=None
                 ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
                                                                      num_heads,
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.block_id = block_id

    def forward(
        self,
        x,
        e,
        grid_sizes,
        freqs,
        context,
        hints=None,
        context_scale=1.0,
        cam_emb=None,
        block_mask=None
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        hint = None
        attention_dtype = self.self_attn.q.weight.dtype
        dtype = x.dtype

        if self.block_id is not None and hints is not None:
            kwargs = {
                "grid_sizes": grid_sizes,
                "freqs": freqs,
                "context": context,
                "e": e,
            }
            if self.block_id == 0:
                hint = self.vace(hints, x, **kwargs)
            else:
                hint = self.vace(hints, None, **kwargs)
        latent_frames = e.shape[0]
        e = (self.modulation + e).chunk(6, dim=1)

        # self-attention
        x_mod = self.norm1(x)
        x_mod = reshape_latent(x_mod, latent_frames)
        x_mod *= 1 + e[1]
        x_mod += e[0]
        x_mod = reshape_latent(x_mod, 1)
        if cam_emb is not None:
            cam_emb = self.cam_encoder(cam_emb)
            cam_emb = cam_emb.repeat(1, 2, 1)
            cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
            cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
            x_mod += cam_emb

        xlist = [x_mod.to(attention_dtype)]
        del x_mod
        y = self.self_attn(xlist, grid_sizes, freqs, block_mask)
        y = y.to(dtype)

        if cam_emb is not None:
            y = self.projector(y)

        x, y = reshape_latent(x, latent_frames), reshape_latent(y, latent_frames)
        x.addcmul_(y, e[2])
        x, y = reshape_latent(x, 1), reshape_latent(y, 1)
        del y
        y = self.norm3(x)
        y = y.to(attention_dtype)
        ylist = [y]
        del y
        x += self.cross_attn(ylist, context).to(dtype)

        y = self.norm2(x)

        y = reshape_latent(y, latent_frames)
        y *= 1 + e[4]
        y += e[3]
        y = reshape_latent(y, 1)
        y = y.to(attention_dtype)

        ffn = self.ffn[0]
        gelu = self.ffn[1]
        ffn2 = self.ffn[2]

        # run the FFN over chunks of tokens to keep peak VRAM low
        y_shape = y.shape
        y = y.view(-1, y_shape[-1])
        chunk_size = int(y_shape[1] / 2.7)
        chunks = torch.split(y, chunk_size)
        for y_chunk in chunks:
            mlp_chunk = ffn(y_chunk)
            mlp_chunk = gelu(mlp_chunk)
            y_chunk[...] = ffn2(mlp_chunk)
            del mlp_chunk
        y = y.view(y_shape)
        y = y.to(dtype)
        x, y = reshape_latent(x, latent_frames), reshape_latent(y, latent_frames)
        x.addcmul_(y, e[5])
        x, y = reshape_latent(x, 1), reshape_latent(y, 1)

        if hint is not None:
            if context_scale == 1:
                x.add_(hint)
            else:
                x.add_(hint, alpha=context_scale)
        return x

class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(
            self,
            cross_attn_type,
            dim,
            ffn_dim,
            num_heads,
            window_size=(-1, -1),
            qk_norm=True,
            cross_attn_norm=False,
            eps=1e-6,
            block_id=0
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)
        self.after_proj = nn.Linear(self.dim, self.dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, hints, x, **kwargs):
        # behold dbm magic !
        c = hints[0]
        hints[0] = None
        if self.block_id == 0:
            c = self.before_proj(c)
            c += x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        hints[0] = c
        return c_skip

class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)
        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        # assert e.dtype == torch.float32
        dtype = x.dtype

        latent_frames = e.shape[0]
        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
        x = self.norm(x).to(dtype)
        x = reshape_latent(x, latent_frames)
        x *= (1 + e[1])
        x += e[0]
        x = reshape_latent(x, 1)
        x = x.to(self.head.weight.dtype)
        x = self.head(x)
        return x

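# Illustrative note (not part of the original module): the head projects each
# token back to a full patch of output channels, so with the WanModel defaults
# (patch_size=(1, 2, 2), out_dim=16) the linear layer maps dim -> 1*2*2*16 = 64
# values per token before unpatchify rebuilds the video latent.
#
#     head = Head(dim=2048, out_dim=16, patch_size=(1, 2, 2))
#     # head.head.out_features == 64
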
class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens

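# Illustrative usage (not part of the original module): in the i2v variant this
# MLP maps CLIP image features (1280-d) into the transformer width so they can
# be prepended to the text context as 257 extra tokens:
#
#     img_emb = MLPProj(1280, 2048)
#     clip_fea = torch.randn(1, 257, 1280)      # hypothetical CLIP output
#     tokens = img_emb(clip_fea)                # -> [1, 257, 2048]
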
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 vace_layers=None,
                 vace_in_dim=None,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 recammaster=False,
                 inject_sample_info=False,
                 ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__()

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.num_frame_per_block = 1
        self.flag_causal_attention = False
        self.block_mask = None
        self.inject_sample_info = inject_sample_info

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        if inject_sample_info:
            self.fps_embedding = nn.Embedding(2, dim)
            self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        if vace_layers is None:
            cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
            self.blocks = nn.ModuleList([
                WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                                window_size, qk_norm, cross_attn_norm, eps)
                for _ in range(num_layers)
            ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

        if vace_layers is not None:
            self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
            self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim

            assert 0 in self.vace_layers
            self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}

            # blocks
            self.blocks = nn.ModuleList([
                WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
                                    self.cross_attn_norm, self.eps,
                                    block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
                for i in range(self.num_layers)
            ])

            # vace blocks
            self.vace_blocks = nn.ModuleList([
                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
                                        self.cross_attn_norm, self.eps, block_id=i)
                for i in self.vace_layers
            ])

            # vace patch embeddings
            self.vace_patch_embedding = nn.Conv3d(
                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
            )
        if recammaster:
            dim = self.blocks[0].self_attn.q.weight.shape[0]
            for block in self.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

    def lock_layers_dtypes(self, dtype=torch.float32, force=False):
        count = 0
        layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
                      self.time_projection, self.time_projection[1]]  # , self.text_embedding, self.text_embedding[0], self.text_embedding[2]
        if hasattr(self, "fps_embedding"):
            layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]

        if hasattr(self, "vace_patch_embedding"):
            layer_list += [self.vace_patch_embedding]
            layer_list += [self.vace_blocks[0].before_proj]
            for block in self.vace_blocks:
                layer_list += [block.after_proj, block.norm3]

        # cam master
        if hasattr(self.blocks[0], "projector"):
            for block in self.blocks:
                layer_list += [block.projector]

        for block in self.blocks:
            layer_list += [block.norm3]
        for layer in layer_list:
            if hasattr(layer, "weight"):
                if layer.weight.dtype == dtype:
                    count += 1
                elif force:
                    if hasattr(layer, "weight"):
                        layer.weight.data = layer.weight.data.to(dtype)
                    if hasattr(layer, "bias"):
                        layer.bias.data = layer.bias.data.to(dtype)
                    count += 1

            layer._lock_dtype = dtype

        if count > 0:
            self._lock_dtype = dtype

    def compute_teacache_threshold(self, start_step, timesteps=None, speed_factor=0):
        rescale_func = np.poly1d(self.coefficients)
        e_list = []
        for t in timesteps:
            t = torch.stack([t])
            time_emb = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype))  # b, dim
            e_list.append(time_emb)

        best_threshold = 0.01
        best_diff = 1000
        best_signed_diff = 1000
        target_nb_steps = int(len(timesteps) / speed_factor)
        threshold = 0.01
        while threshold <= 0.6:
            accumulated_rel_l1_distance = 0
            nb_steps = 0
            diff = 1000
            for i, t in enumerate(timesteps):
                skip = False
                # never skip the warm-up steps or the last step
                if not (i <= start_step or i == len(timesteps) - 1):
                    accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
                    if accumulated_rel_l1_distance < threshold:
                        skip = True
                    else:
                        accumulated_rel_l1_distance = 0
                if not skip:
                    nb_steps += 1
                    signed_diff = target_nb_steps - nb_steps
                    diff = abs(signed_diff)
            if diff < best_diff:
                best_threshold = threshold
                best_diff = diff
                best_signed_diff = signed_diff
            elif diff > best_diff:
                break
            threshold += 0.01
        self.rel_l1_thresh = best_threshold
        print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
        return best_threshold

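    # Illustrative usage (not part of the original module): given the scheduler's
    # timesteps, this sweep picks the smallest rel-L1 threshold whose simulated
    # number of non-skipped steps is closest to len(timesteps) / speed_factor,
    # assuming self.coefficients has been set by the calling pipeline, e.g.
    #
    #     model.compute_teacache_threshold(start_step=5, timesteps=timesteps, speed_factor=2)
    #     # sets model.rel_l1_thresh and returns the chosen threshold
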
    def forward(
        self,
        x,
        t,
        context,
        vace_context=None,
        vace_context_scale=1.0,
        clip_fea=None,
        y=None,
        freqs=None,
        pipeline=None,
        current_step=0,
        is_uncond=False,
        max_steps=0,
        slg_layers=None,
        callback=None,
        cam_emb: torch.Tensor = None,
        fps=None,
        causal_block_size=1,
        causal_attention=False,
        x_neg=None
    ):
        # dtype = self.blocks[0].self_attn.q.weight.dtype
        dtype = self.patch_embedding.weight.dtype

        if self.model_type == 'i2v':
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if torch.is_tensor(freqs) and freqs.device != device:
            freqs = freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
        if x_neg is not None:
            x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]

        grid_sizes = [list(u.shape[2:]) for u in x]
        embed_sizes = grid_sizes[0]
        if causal_attention:  # causal_block_size > 0
            frame_num = embed_sizes[0]
            height = embed_sizes[1]
            width = embed_sizes[2]
            block_num = frame_num // causal_block_size
            range_tensor = torch.arange(block_num).view(-1, 1)
            range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
            causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
            causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
            block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            del causal_mask

        offload.shared_state["embed_sizes"] = embed_sizes
        offload.shared_state["step_no"] = current_step
        offload.shared_state["max_steps"] = max_steps

        x = [u.flatten(2).transpose(1, 2) for u in x]
        x = x[0]
        if x_neg is not None:
            x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
            x_neg = x_neg[0]
        if t.dim() == 2:
            b, f = t.shape
            _flag_df = True
        else:
            _flag_df = False
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype)  # self.patch_embedding.weight.dtype
        )  # b, dim
        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)

        if self.inject_sample_info:
            fps = torch.tensor(fps, dtype=torch.long, device=device)

            fps_emb = self.fps_embedding(fps).to(dtype)  # float()
            if _flag_df:
                e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
            else:
                e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))

        # context
        context = [self.text_embedding(torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]).unsqueeze(0)) for u in context]

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = [torch.cat([context_clip, u], dim=1) for u in context]

        joint_pass = len(context) > 0
        x_list = [x]
        if joint_pass:
            if x_neg is None:
                x_list += [x.clone() for i in range(len(context) - 1)]
            else:
                x_list += [x.clone() for i in range(len(context) - 2)] + [x_neg]
            is_uncond = False
        del x
        context_list = context

        # arguments
        kwargs = dict(
            grid_sizes=grid_sizes,
            freqs=freqs,
            cam_emb=cam_emb
            )

        if vace_context is None:
            hints_list = [None] * len(x_list)
        else:
            # embeddings
            c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
            c = [u.flatten(2).transpose(1, 2) for u in c]
            c = c[0]

            kwargs['context_scale'] = vace_context_scale
            hints_list = [[c] for _ in range(len(x_list))]
            del c
        should_calc = True
        if self.enable_teacache:
            if is_uncond:
                should_calc = self.should_calc
            else:
                if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
                    if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                        should_calc = False
                        self.teacache_skipped_steps += 1
                        # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
                    else:
                        should_calc = True
                        self.accumulated_rel_l1_distance = 0
                self.previous_modulated_input = e
                self.should_calc = should_calc

        if not should_calc:
            for i, x in enumerate(x_list):
                x += self.previous_residual_uncond if i == 1 or is_uncond else self.previous_residual_cond
        else:
            if self.enable_teacache:
                if joint_pass or is_uncond:
                    self.previous_residual_uncond = None
                if joint_pass or not is_uncond:
                    self.previous_residual_cond = None
                ori_hidden_states = x_list[0].clone()

            for block_idx, block in enumerate(self.blocks):
                offload.shared_state["layer"] = block_idx
                if callback is not None:
                    callback(-1, None, False, True)
                if pipeline._interrupt:
                    return [None] * len(x_list)

                if slg_layers is not None and block_idx in slg_layers:
                    if is_uncond and not joint_pass:
                        continue
                    x_list[0] = block(x_list[0], context=context_list[0], e=e0, **kwargs)

                else:
                    for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
                        x_list[i] = block(x, context=context, hints=hints, e=e0, **kwargs)
                        del x
                    del context, hints

            if self.enable_teacache:
                if joint_pass:
                    self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
                    self.previous_residual_uncond = ori_hidden_states
                    torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
                else:
                    residual = ori_hidden_states  # just to have a readable code
                    torch.sub(x_list[0], ori_hidden_states, out=residual)
                    if i == 1 or is_uncond:
                        self.previous_residual_uncond = residual
                    else:
                        self.previous_residual_cond = residual
                residual, ori_hidden_states = None, None
        for i, x in enumerate(x_list):
            # head
            x = self.head(x, e)

            # unpatchify
            x_list[i] = self.unpatchify(x, grid_sizes)
            del x

        return [x[0].float() for x in x_list]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

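    # Illustrative shape round trip (not part of the original module): with
    # patch_size=(1, 2, 2) and out_dim=16, a latent grid of (F, H, W) patches
    # yields F*H*W tokens of 64 values each from the head, and unpatchify folds
    # them back into a video latent of shape [16, F*1, H*2, W*2], e.g., assuming
    # `model` is a constructed WanModel with those defaults:
    #
    #     tokens = torch.randn(21 * 30 * 52, 64)          # hypothetical 21x30x52 grid
    #     video = model.unpatchify([tokens], [[21, 30, 52]])[0]
    #     # video.shape == torch.Size([16, 21, 60, 104])
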
    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)