From f9ce97a1ba63a43e1e3c5e96fe5bdec2111d128a Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sat, 8 Mar 2025 16:37:21 +0100
Subject: [PATCH] Fixed PyTorch compilation

---
 README.md                    |   1 +
 gradio_server.py             |   2 +-
 wan/image2video.py           |   5 +-
 wan/modules/model.py         |  38 +--
 wan/modules/posemb_layers.py | 474 +++++++++++++++++++++++++++++++++++
 wan/text2video.py            |  13 +-
 6 files changed, 502 insertions(+), 31 deletions(-)
 create mode 100644 wan/modules/posemb_layers.py

diff --git a/README.md b/README.md
index 1bede98..c086297 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
 
+* Mar 03, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; when activated, it now really delivers the ~20% speedup
 * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process. If you upgrade you will need to do a 'pip install -r requirements.txt' again.
 * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
diff --git a/gradio_server.py b/gradio_server.py
index eec6d14..8055f52 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -853,7 +853,7 @@ def generate_video(
         if use_image2video:
             samples = wan_model.generate(
                 prompt,
-                image_to_continue[video_no-1],
+                image_to_continue[ (video_no-1) % len(image_to_continue)],
                 frame_num=(video_length // 4)* 4 + 1,
                 max_area=MAX_AREA_CONFIGS[resolution],
                 shift=flow_shift,
diff --git a/wan/image2video.py b/wan/image2video.py
index f295141..5428c68 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -24,7 +24,7 @@ from .modules.vae import WanVAE
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-
+from wan.modules.posemb_layers import get_rotary_pos_embed
 
 
 class WanI2V:
@@ -290,7 +290,7 @@ class WanI2V:
 
         # sample videos
         latent = noise
-        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
+        freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx )
 
         arg_c = {
             'context': [context[0]],
@@ -318,6 +318,7 @@ class WanI2V:
             callback(-1, None)
 
         for i, t in enumerate(tqdm(timesteps)):
+            offload.set_step_no_for_lora(i)
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
 
diff --git a/wan/modules/model.py b/wan/modules/model.py
index e0afa99..11b57ab 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -67,17 +67,15 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
     inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
 
     freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
+    if True:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        return (freqs_cos, freqs_sin)
+    else:
+        freqs = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
     return freqs
 
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float32).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
+
 
 
 def rope_apply_(x, grid_sizes, freqs):
@@ -209,6 +207,7 @@ class WanLayerNorm(nn.LayerNorm):
         return x
         # return super().forward(x).type_as(x)
 
+from wan.modules.posemb_layers import apply_rotary_emb
 
 class WanSelfAttention(nn.Module):
@@ -257,8 +256,11 @@ class WanSelfAttention(nn.Module):
         k = k.view(b, s, n, d)
         v = self.v(x).view(b, s, n, d)
         del x
-        rope_apply_(q, grid_sizes, freqs)
-        rope_apply_(k, grid_sizes, freqs)
+        # rope_apply_(q, grid_sizes, freqs)
+        # rope_apply_(k, grid_sizes, freqs)
+        qklist = [q,k]
+        del q,k
+        q,k = apply_rotary_emb(qklist, freqs, head_first=False)
         qkv_list = [q,k,v]
         del q,k,v
         x = pay_attention(
@@ -652,20 +654,18 @@ class WanModel(ModelMixin, ConfigMixin):
 
     #     ],dim=1)
 
-    def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
+    def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
         dim = self.dim
         num_heads = self.num_heads
         d = dim // num_heads
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
 
-        freqs = torch.cat([
-            rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44
-            rope_params(1024, 2 * (d // 6)), #42
-            rope_params(1024, 2 * (d // 6)) #42
-        ],dim=1)
+        c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
+        c2, s2 = rope_params(1024, 2 * (d // 6)) #42
+        c3, s3 = rope_params(1024, 2 * (d // 6)) #42
 
-        return freqs
+        return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
 
 
     def forward(
@@ -706,7 +706,7 @@ class WanModel(ModelMixin, ConfigMixin):
             assert clip_fea is not None and y is not None
         # params
         device = self.patch_embedding.weight.device
-        if freqs.device != device:
+        if torch.is_tensor(freqs) and freqs.device != device:
             freqs = freqs.to(device)
 
         if y is not None:
diff --git a/wan/modules/posemb_layers.py b/wan/modules/posemb_layers.py
new file mode 100644
index 0000000..86b8078
--- /dev/null
+++ b/wan/modules/posemb_layers.py
@@ -0,0 +1,474 @@
+import torch
+from typing import Union, Tuple, List, Optional
+import numpy as np
+
+
+###### Thanks to the RifleX project (https://github.com/thu-ml/RIFLEx/) for this alternative pos embed for long videos
+#
+def get_1d_rotary_pos_embed_riflex(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    k: Optional[int] = None,
+    L_test: Optional[int] = None,
+):
+    """
+    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
+        L_test (`int`, *optional*, defaults to None): the number of frames for inference
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    assert dim % 2 == 0
+
+    if isinstance(pos, int):
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
+    )  # [D/2]
+
+    # === Riflex modification start ===
+    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
+    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
+    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
+    if k is not None:
+        freqs[k-1] = 0.9 * 2 * torch.pi / L_test
+    # === Riflex modification end ===
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        # lumina
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
+        return freqs_cis
+
+def identify_k( b: float, d: int, N: int):
+    """
+    This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
+
+    Args:
+        b (`float`): The base frequency for RoPE.
+        d (`int`): Dimension of the frequency tensor
+        N (`int`): the first observed repetition frame in latent space
+    Returns:
+        k (`int`): the index of the intrinsic frequency component
+        N_k (`int`): the period of the intrinsic frequency component in latent space
+    Example:
+        In HunyuanVideo, b=256 and d=16, the repetition occurs at approximately 8s (N=48 in latent space).
+        k, N_k = identify_k(b=256, d=16, N=48)
+        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
+    """
+
+    # Compute the period of each frequency in RoPE according to Eq.(4)
+    periods = []
+    for j in range(1, d // 2 + 1):
+        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
+        N_j = round(2 * torch.pi / theta_j)
+        periods.append(N_j)
+
+    # Identify the intrinsic frequency whose period is close to N (see Eq.(7))
+    diffs = [abs(N_j - N) for N_j in periods]
+    k = diffs.index(min(diffs)) + 1
+    N_k = periods[k-1]
+    return k, N_k
+
+def _to_tuple(x, dim=2):
+    if isinstance(x, int):
+        return (x,) * dim
+    elif len(x) == dim:
+        return x
+    else:
+        raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+def get_meshgrid_nd(start, *args, dim=2):
+    """
+    Get n-D meshgrid with start, stop and num.
+
+    Args:
+        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
+            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
+            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
+            n-tuples.
+        *args: See above.
+        dim (int): Dimension of the meshgrid. Defaults to 2.
+
+    Returns:
+        grid (np.ndarray): [dim, ...]
+    """
+    if len(args) == 0:
+        # start is grid_size
+        num = _to_tuple(start, dim=dim)
+        start = (0,) * dim
+        stop = num
+    elif len(args) == 1:
+        # start is start, args[0] is stop, step is 1
+        start = _to_tuple(start, dim=dim)
+        stop = _to_tuple(args[0], dim=dim)
+        num = [stop[i] - start[i] for i in range(dim)]
+    elif len(args) == 2:
+        # start is start, args[0] is stop, args[1] is num
+        start = _to_tuple(start, dim=dim)  # Left-Top   eg: 12,0
+        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
+        num = _to_tuple(args[1], dim=dim)  # Target Size   eg: 32,124
+    else:
+        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
+    axis_grid = []
+    for i in range(dim):
+        a, b, n = start[i], stop[i], num[i]
+        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+        axis_grid.append(g)
+    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
+    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]
+
+    return grid
+
+
+#################################################################################
+#                   Rotary Positional Embedding Functions                       #
+#################################################################################
+# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
+
+
+def reshape_for_broadcast(
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    x: torch.Tensor,
+    head_first=False,
+):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Notes:
+        When using FlashMHAModified, head_first should be False.
+        When using Attention, head_first should be True.
+
+    Args:
+        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+        head_first (bool): head dimension first (except batch dim) or not.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+    """
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+
+    if isinstance(freqs_cis, tuple):
+        # freqs_cis: (cos, sin) in real space
+        if head_first:
+            assert freqs_cis[0].shape == (
+                x.shape[-2],
+                x.shape[-1],
+            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+            shape = [
+                d if i == ndim - 2 or i == ndim - 1 else 1
+                for i, d in enumerate(x.shape)
+            ]
+        else:
+            assert freqs_cis[0].shape == (
+                x.shape[1],
+                x.shape[-1],
+            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+    else:
+        # freqs_cis: values in complex space
+        if head_first:
+            assert freqs_cis.shape == (
+                x.shape[-2],
+                x.shape[-1],
+            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+            shape = [
+                d if i == ndim - 2 or i == ndim - 1 else 1
+                for i, d in enumerate(x.shape)
+            ]
+        else:
+            assert freqs_cis.shape == (
+                x.shape[1],
+                x.shape[-1],
+            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+    x_real, x_imag = (
+        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+    )  # [B, S, H, D//2]
+    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb( qklist,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+    head_first: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor.
+
+    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+    returned as real tensors.
+
+    Args:
+        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
+        head_first (bool): head dimension first (except batch dim) or not.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+    """
+    xq, xk = qklist
+    qklist.clear()
+    xk_out = None
+    if isinstance(freqs_cis, tuple):
+        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
+        cos, sin = cos.to(xq.device), sin.to(xq.device)
+        # real * cos - imag * sin
+        # imag * cos + real * sin
+        xq_dtype = xq.dtype
+        xq_out = xq.to(torch.float)
+        xq = None
+        xq_rot = rotate_half(xq_out)
+        xq_out *= cos
+        xq_rot *= sin
+        xq_out += xq_rot
+        del xq_rot
+        xq_out = xq_out.to(xq_dtype)
+
+        xk_out = xk.to(torch.float)
+        xk = None
+        xk_rot = rotate_half(xk_out)
+        xk_out *= cos
+        xk_rot *= sin
+        xk_out += xk_rot
+        del xk_rot
+        xk_out = xk_out.to(xq_dtype)
+    else:
+        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
+        xq_ = torch.view_as_complex(
+            xq.float().reshape(*xq.shape[:-1], -1, 2)
+        )  # [B, S, H, D//2]
+        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
+            xq.device
+        )  # [S, D//2] --> [1, S, 1, D//2]
+        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
+        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+        xk_ = torch.view_as_complex(
+            xk.float().reshape(*xk.shape[:-1], -1, 2)
+        )  # [B, S, H, D//2]
+        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+    return xq_out, xk_out
+
+
+
+
+
+    return xq_out, xk_out
+def get_nd_rotary_pos_embed(
+    rope_dim_list,
+    start,
+    *args,
+    theta=10000.0,
+    use_real=False,
+    theta_rescale_factor: Union[float, List[float]] = 1.0,
+    interpolation_factor: Union[float, List[float]] = 1.0,
+    k = 6,
+    L_test = 66,
+    enable_riflex = True
+):
+    """
+    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
+
+    Args:
+        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
+            sum(rope_dim_list) should equal to head_dim of attention layer.
+        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
+            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+        *args: See above.
+        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+            Some libraries such as TensorRT do not support complex64 data type. So it is useful to provide a real
+            part and an imaginary part separately.
+        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+
+    Returns:
+        pos_embed (torch.Tensor): [HW, D/2]
+    """
+
+    grid = get_meshgrid_nd(
+        start, *args, dim=len(rope_dim_list)
+    )  # [3, W, H, D] / [2, W, H]
+
+    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
+        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+    assert len(theta_rescale_factor) == len(
+        rope_dim_list
+    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
+
+    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
+        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+    assert len(interpolation_factor) == len(
+        rope_dim_list
+    ), "len(interpolation_factor) should equal to len(rope_dim_list)"
+
+    # use 1/ndim of dimensions to encode grid_axis
+    embs = []
+    for i in range(len(rope_dim_list)):
+        # emb = get_1d_rotary_pos_embed(
+        #     rope_dim_list[i],
+        #     grid[i].reshape(-1),
+        #     theta,
+        #     use_real=use_real,
+        #     theta_rescale_factor=theta_rescale_factor[i],
+        #     interpolation_factor=interpolation_factor[i],
+        # )  # 2 x [WHD, rope_dim_list[i]]
+
+
+        # === RIFLEx modification start ===
+        # apply RIFLEx for time dimension
+        if i == 0 and enable_riflex:
+            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
+        # === RIFLEx modification end ===
+        else:
+            emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
+        embs.append(emb)
+
+    if use_real:
+        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
+        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[torch.FloatTensor, int],
+    theta: float = 10000.0,
+    use_real: bool = False,
+    theta_rescale_factor: float = 1.0,
+    interpolation_factor: float = 1.0,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
+    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (bool, optional): If True, return real part and imaginary part separately.
+            Otherwise, return complex numbers.
+        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+
+    Returns:
+        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
+        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
+    """
+    if isinstance(pos, int):
+        pos = torch.arange(pos).float()
+
+    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+    # has some connection to NTK literature
+    if theta_rescale_factor != 1.0:
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+    )  # [D/2]
+    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
+    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(
+            torch.ones_like(freqs), freqs
+        )  # complex64 # [S, D/2]
+        return freqs_cis
+
+def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
+    target_ndim = 3
+    ndim = 5 - 2
+
+    latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
+    patch_size = [1, 2, 2]
+    if isinstance(patch_size, int):
+        assert all(s % patch_size == 0 for s in latents_size), (
+            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
+            f"but got {latents_size}."
+        )
+        rope_sizes = [s // patch_size for s in latents_size]
+    elif isinstance(patch_size, list):
+        assert all(
+            s % patch_size[idx] == 0
+            for idx, s in enumerate(latents_size)
+        ), (
+            f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
+            f"but got {latents_size}."
+        )
+        rope_sizes = [
+            s // patch_size[idx] for idx, s in enumerate(latents_size)
+        ]
+
+    if len(rope_sizes) != target_ndim:
+        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
+    head_dim = 128
+    rope_dim_list = [44, 42, 42]
+    if rope_dim_list is None:
+        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+    assert (
+        sum(rope_dim_list) == head_dim
+    ), "sum(rope_dim_list) should equal to head_dim of attention layer"
+    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+        rope_dim_list,
+        rope_sizes,
+        theta=10000,
+        use_real=True,
+        theta_rescale_factor=1,
+        L_test = (video_length - 1) // 4 + 1,
+        enable_riflex = enable_RIFLEx
+    )
+    return (freqs_cos, freqs_sin)
\ No newline at end of file
diff --git a/wan/text2video.py b/wan/text2video.py
index c7efb92..0b3bc2d 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -8,7 +8,7 @@ import sys
 import types
 from contextlib import contextmanager
 from functools import partial
-
+from mmgp import offload
 import torch
 import torch.cuda.amp as amp
 import torch.distributed as dist
@@ -21,6 +21,7 @@ from .modules.vae import WanVAE
 from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from wan.modules.posemb_layers import get_rotary_pos_embed
 
 
 class WanT2V:
@@ -236,13 +237,7 @@ class WanT2V:
 
         # sample videos
         latents = noise
-        # from .modules.model import identify_k
-        # for nf in range(20, 50):
-        #     k, N_k = identify_k(10000, 44, 26)
-        #     print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
-
-        freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
-
+        freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
         arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
@@ -252,7 +247,7 @@ class WanT2V:
             for i, t in enumerate(tqdm(timesteps)):
                 latent_model_input = latents
                 timestep = [t]
-
+                offload.set_step_no_for_lora(i)
                 timestep = torch.stack(timestep)
 
                 # self.model.to(self.device)
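
The functional core of the fix is visible in rope_params_riflex and apply_rotary_emb above: the complex-valued RoPE path (torch.polar, view_as_complex, view_as_real) is swapped for precomputed real cos/sin tables applied with a rotate_half-style rotation, and the commit message ties the recovered ~20% speedup to exactly this change, presumably because the compiled graph no longer has to handle complex64 tensors. A minimal, self-contained sketch (toy shapes assumed purely for illustration) showing that the two formulations agree:

import torch

# Assumed toy shapes: batch 1, 8 tokens, 2 heads, head_dim 16 (head_first=False layout).
S, H, D = 8, 2, 16
x = torch.randn(1, S, H, D)                                            # [B, S, H, D]

pos = torch.arange(S).float()
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))      # [D/2]
freqs = torch.outer(pos, inv_freq)                                     # [S, D/2]

# Complex formulation: rotate each (even, odd) channel pair as a complex number.
x_c = torch.view_as_complex(x.float().reshape(1, S, H, D // 2, 2))
cis = torch.polar(torch.ones_like(freqs), freqs).view(1, S, 1, D // 2)
ref = torch.view_as_real(x_c * cis).flatten(3)                         # [B, S, H, D]

# Real formulation used by the patch: out = x * cos + rotate_half(x) * sin.
cos = freqs.cos().repeat_interleave(2, dim=1).view(1, S, 1, D)
sin = freqs.sin().repeat_interleave(2, dim=1).view(1, S, 1, D)
x_r, x_i = x.float().reshape(1, S, H, D // 2, 2).unbind(-1)
rotated = torch.stack([-x_i, x_r], dim=-1).flatten(3)                  # rotate_half(x)
out = x.float() * cos + rotated * sin

assert torch.allclose(out, ref, atol=1e-5)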
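
The freqs tuple that now flows into arg_c/arg_null comes from get_rotary_pos_embed, which maps the requested frame count and resolution to latent sizes ((frames - 1) // 4 + 1, height // 8, width // 8), divides by the [1, 2, 2] patch size, and builds 44/42/42-dim cos/sin tables for the time, height and width axes. A small usage sketch (the 81-frame, 480x832 numbers are only an assumed example):

from wan.modules.posemb_layers import get_rotary_pos_embed

# 81 frames -> (81 - 1) // 4 + 1 = 21 latent frames; 480x832 -> 60x104 latent pixels;
# patching by [1, 2, 2] gives rope sizes [21, 30, 52], so each table has shape
# [21 * 30 * 52, 44 + 42 + 42] = [32760, 128].
freqs_cos, freqs_sin = get_rotary_pos_embed(81, 480, 832, enable_RIFLEx=False)
print(freqs_cos.shape, freqs_sin.shape)   # torch.Size([32760, 128]) torch.Size([32760, 128])

With enable_RIFLEx=True only the temporal 44-dim block changes: its k-th frequency (k=6 by default) is clamped to 0.9 * 2 * pi / L_test, so for the 21 latent frames in this example the slowest-repeating component completes less than one full period and tail-frame repetition is avoided.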