import torch
from typing import Union, Tuple, List, Optional
import numpy as np


# Thanks to the RIFLEx project (https://github.com/thu-ml/RIFLEx/) for this
# alternative positional embedding for long videos.
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension
    'dim' and position indices 'pos'. The 'theta' parameter scales the frequencies. The returned
    tensor contains complex values in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If True, return the real and imaginary parts separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): Index of the intrinsic frequency in RoPE.
        L_test (`int`, *optional*, defaults to None): Number of frames for inference.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
    )  # [D/2]

    # === RIFLEx modification start ===
    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
    if k is not None:
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
    # === RIFLEx modification end ===

    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
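

# A minimal usage sketch (not part of the upstream file): how the RIFLEx
# embedding above might be called for a long-video time axis. The numbers
# (dim=16, 66 latent frames, k=4) are illustrative assumptions only.
def _demo_riflex_time_axis():
    cos, sin = get_1d_rotary_pos_embed_riflex(
        dim=16, pos=66, use_real=True, k=4, L_test=66
    )
    # After repeat_interleave, cos/sin each have shape [S, D] = [66, 16].
    print(cos.shape, sin.shape)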


def identify_k(b: float, d: int, N: int):
    """
    Identify the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.

    Args:
        b (`float`): The base frequency for RoPE.
        d (`int`): Dimension of the frequency tensor.
        N (`int`): The first observed repetition frame in latent space.

    Returns:
        k (`int`): The index of the intrinsic frequency component.
        N_k (`int`): The period of the intrinsic frequency component in latent space.

    Example:
        In HunyuanVideo, b=256 and d=16, and the repetition occurs at approximately 8 s (N=48 in latent space).
        k, N_k = identify_k(b=256, d=16, N=48)
        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
    """

    # Compute the period of each frequency in RoPE according to Eq. (4).
    periods = []
    for j in range(1, d // 2 + 1):
        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
        N_j = round(2 * torch.pi / theta_j)
        periods.append(N_j)

    # Identify the intrinsic frequency whose period is closest to N (see Eq. (7)).
    diffs = [abs(N_j - N) for N_j in periods]
    k = diffs.index(min(diffs)) + 1
    N_k = periods[k - 1]
    return k, N_k
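

# Hedged sketch: reproduces the HunyuanVideo example from the docstring above.
def _demo_identify_k():
    k, N_k = identify_k(b=256, d=16, N=48)
    print(k, N_k)  # expected: 4 50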


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get an n-D meshgrid from start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0] is stop,
            with step 1; if len(args) == 2, start is start, args[0] is stop, and args[1] is num. For n dims,
            start/stop/num should each be an int or an n-tuple. If an n-tuple is provided, the meshgrid is stacked
            following the dim order of the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top       e.g. 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom  e.g. 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size    e.g. 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid
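

# Hedged sketch of get_meshgrid_nd: a (t, h, w) grid for an assumed 4x8x8
# latent; the stacked grid has shape [dim, t, h, w] = [3, 4, 8, 8].
def _demo_meshgrid():
    grid = get_meshgrid_nd((4, 8, 8), dim=3)
    print(grid.shape)  # torch.Size([3, 4, 8, 8])
    print(grid[0, :, 0, 0])  # time coordinates 0..3 along the first axis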


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80


def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape a frequency tensor for broadcasting with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): Whether the head dimension comes first (except for the batch dim).

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    # Pairwise rotation of the last dim: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x_real, x_imag = (
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    )  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the query and key tensors passed in 'qklist' using the
    provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the
    frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary
    embeddings and are returned as real tensors. The input list is cleared so the original tensors can
    be freed early.

    Args:
        qklist (List[torch.Tensor]): [xq, xk], the query and key tensors to rotate, each [B, S, H, D].
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for the complex exponentials.
        head_first (bool): Whether the head dimension comes first (except for the batch dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with rotary embeddings applied.
    """
    xq, xk = qklist
    qklist.clear()
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_dtype = xq.dtype
        xq_out = xq.to(torch.float)
        xq = None
        xq_rot = rotate_half(xq_out)
        xq_out *= cos
        xq_rot *= sin
        xq_out += xq_rot
        del xq_rot
        xq_out = xq_out.to(xq_dtype)

        xk_out = xk.to(torch.float)
        xk = None
        xk_rot = rotate_half(xk_out)
        xk_out *= cos
        xk_rot *= sin
        xk_out += xk_rot
        del xk_rot
        xk_out = xk_out.to(xq_dtype)
    else:
        # view_as_complex will pack [..., D/2, 2] (real) to [..., D/2] (complex)
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2] (complex) to [..., D/2, 2] (real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
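

# Hedged sketch of apply_rotary_emb with a complex freqs_cis (the else branch
# above). Shapes are illustrative: [B, S, H, D] with head_first=False, and
# freqs_cis of shape [S, D/2] from get_1d_rotary_pos_embed defined below.
def _demo_apply_rotary_emb():
    B, S, H, D = 1, 16, 2, 8
    q = torch.randn(B, S, H, D)
    k = torch.randn(B, S, H, D)
    freqs_cis = get_1d_rotary_pos_embed(D, S)  # complex64, [16, 4]
    q_rot, k_rot = apply_rotary_emb([q, k], freqs_cis)  # the list is cleared
    print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 16, 2, 8])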


def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000.0, use_real=False,
                                theta_rescale_factor: Union[float, List[float]] = 1.0,
                                interpolation_factor: Union[float, List[float]] = 1.0,
                                concat_dict={},
                                k=4,
                                L_test=66,
                                enable_riflex=True,
                                ):
    """
    Variant of get_nd_rotary_pos_embed that can prepend extra (conditioning) positions along the
    time axis, controlled by concat_dict ('timecat' or 'timecat-w' mode with a 'bias' value).
    """
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]
    if concat_dict:
        if concat_dict['mode'] == 'timecat':
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            grid = torch.cat([bias, grid], dim=1)
        elif concat_dict['mode'] == 'timecat-w':
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            bias[2] += start[-1]  # ref: https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
            grid = torch.cat([bias, grid], dim=1)
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        # === RIFLEx modification start ===
        # apply RIFLEx to the time dimension
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        # === RIFLEx modification end ===
        else:
            emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True,
                                          theta_rescale_factor=theta_rescale_factor[i],
                                          interpolation_factor=interpolation_factor[i])
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        # Note: the branches above always return (cos, sin) pairs, so use_real=False
        # is not supported by the current inner calls.
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
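

# Hedged sketch of the 'timecat' mode above: one extra time slice with a fixed
# time coordinate (bias) is prepended, e.g. for conditioning tokens. The grid
# size, head-dim split and bias value below are illustrative assumptions.
def _demo_timecat():
    cos, sin = get_nd_rotary_pos_embed_new(
        [44, 42, 42], (4, 8, 8), use_real=True,
        concat_dict={'mode': 'timecat', 'bias': -1}, L_test=4,
    )
    print(cos.shape)  # torch.Size([320, 128]): (4 + 1) * 8 * 8 positions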


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    k=4,
    L_test=66,
    enable_riflex=True,
):
    """
    This is an n-d version of precompute_freqs_cis, i.e. a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start,
            args[0] is stop, and step is 1; if len(args) == 2, start is start, args[0] is stop, and args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return the real and imaginary parts separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the
            real and imaginary parts separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """

    grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list)
    )  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        # === RIFLEx modification start ===
        # apply RIFLEx to the time dimension
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        # === RIFLEx modification end ===
        else:
            emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True,
                                          theta_rescale_factor=theta_rescale_factor[i],
                                          interpolation_factor=interpolation_factor[i])
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        # Note: as above, the inner calls always return (cos, sin) pairs, so
        # use_real=False is not supported here.
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
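

# Hedged sketch: 3-axis RoPE for a (t, h, w) = (66, 30, 52) token grid with an
# assumed head_dim of 128 split as 44 + 42 + 42 (illustrative, not a value
# fixed by this file). RIFLEx is applied to the time axis by default.
def _demo_nd_rope():
    cos, sin = get_nd_rotary_pos_embed(
        [44, 42, 42], (66, 30, 52), use_real=True, k=4, L_test=66
    )
    print(cos.shape)  # torch.Size([102960, 128]): 66 * 30 * 52 positions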


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponentials using the given dimension
    'dim' and position indices 'pos'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return the real and imaginary parts separately.
                                   Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # Proposed by reddit user bloc97, to rescale rotary embeddings to a longer
    # sequence length without fine-tuning. Has some connection to the NTK literature.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(
            torch.ones_like(freqs), freqs
        )  # complex64  # [S, D/2]
        return freqs_cis
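

# Hedged sketch of theta rescaling (the bloc97/NTK trick above): a factor > 1
# stretches the usable context without retraining. Values are illustrative.
def _demo_ntk_rescale():
    cos, sin = get_1d_rotary_pos_embed(
        64, 128, use_real=True, theta_rescale_factor=2.0
    )
    print(cos.shape, sin.shape)  # both torch.Size([128, 64])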