# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    # logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
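
# The xFormers ops above are only needed for the nested-tensor path
# (`NestedTensorBlock` called with a list of tensors); a plain `Block` with the
# default `Attention` class runs without xFormers installed.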


class Block(nn.Module):
    """Transformer block with LayerScale and optional per-sample stochastic depth."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
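

# A minimal usage sketch (hyperparameters and shapes are illustrative, not
# taken from any particular DINOv2 config):
#
#     blk = Block(dim=384, num_heads=6, init_values=1e-5, drop_path=0.1)
#     x = torch.randn(8, 197, 384)  # (batch, tokens, dim)
#     y = blk(x)                    # same shape as x; stochastic depth is training-only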


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)
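

# The residual is scaled by b / sample_subset_size because only the sampled
# rows receive an update; upscaling the kept residuals preserves the expected
# magnitude of the update over the batch. For example, with b=8 and
# sample_drop_ratio=0.25, the subset has int(8 * 0.75) = 6 rows and each kept
# residual is scaled by 8/6 ≈ 1.33.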


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual
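

# The two branches of `add_residual` differ only in fusion: the xFormers
# `scaled_index_add` path folds the LayerScale multiplication
# (`scaling_vector`) into the scatter-add itself, which is why the nested
# training path below passes `ls1.gamma` / `ls2.gamma` explicitly rather than
# applying LayerScale inside the residual functions.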


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Index-select the tensors (when `branges` is given), concatenate them into a
    single sequence, and return the matching block-diagonal attn_bias from cache.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
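

# Illustrative example: for x_list with shapes [(2, 196, d), (1, 784, d)], the
# tensors are flattened into one (1, 2*196 + 784, d) sequence, and
# fmha.BlockDiagonalMask.from_seqlens([196, 196, 784]) constrains attention so
# each token only attends within its original sample; a single attention call
# thus serves both resolutions.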


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
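

# This list variant mirrors drop_add_residual_stochastic_depth: it samples the
# surviving rows of every tensor, concatenates just those rows into one
# sequence so a single attention call covers all resolutions, then splits the
# result and scatter-adds each residual back into its source tensor.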


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        Run the block on a list of tensors nested together into a single
        concatenated sequence.
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                # LayerScale is applied later, fused into add_residual via scaling_vector
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError("x_or_x_list must be a Tensor or a list of Tensors")
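

# A minimal usage sketch for the nested path (requires xFormers; shapes and
# hyperparameters are illustrative):
#
#     blk = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention)
#     x_list = [torch.randn(2, 196, 384), torch.randn(1, 784, 384)]
#     y_list = blk(x_list)  # list of tensors with the same shapes as x_list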