import os
import os.path as osp
import binascii
import subprocess
import uuid
from functools import lru_cache

import imageio
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchvision
from einops import rearrange, repeat
from tqdm import tqdm


VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

ASPECT_RATIO_627 = {
    '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1),
    '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1),
    '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1),
    '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)}

ASPECT_RATIO_960 = {
    '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1),
    '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1),
    '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1),
    '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1),
    '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1),
    '3.75': ([1920, 512], 1)}
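
# The dictionaries above map an aspect ratio (height / width, formatted as a
# string key) to a ([height, width], 1) resolution bucket. A minimal,
# illustrative lookup -- not part of the original module -- could pick the
# closest bucket for an input image like this:
#
#     ratio = 720 / 1280                                       # 0.5625
#     key = min(ASPECT_RATIO_960, key=lambda k: abs(float(k) - ratio))
#     (height, width), _ = ASPECT_RATIO_960[key]               # -> [704, 1280]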


def torch_gc():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):

    S = T * token_frame
    split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]
    counts = [0] * T
    for idx in range(start, end):
        t = idx // token_frame
        counts[t] += 1

    counts_filtered = []
    frame_ids = []
    for t, c in enumerate(counts):
        if c > 0:
            counts_filtered.append(c)
            frame_ids.append(t)
    return counts_filtered, frame_ids
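
# Illustrative example (not part of the original file): with T=3 frames of
# token_frame=4 tokens each, split across world_size=2 ranks, rank 0 owns the
# first 6 tokens, i.e. all of frame 0 and half of frame 1:
#
#     split_token_counts_and_frame_ids(3, 4, 2, 0)   # -> ([4, 2], [0, 1])
#     split_token_counts_and_frame_ids(3, 4, 2, 1)   # -> ([2, 4], [1, 2])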


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):

    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled
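
# Illustrative example (not part of the original file): map values from [0, 10]
# onto [-1, 1]; epsilon only guards against division by zero, so the result is
# exact up to ~1e-9:
#
#     normalize_and_scale(torch.tensor([0.0, 5.0, 10.0]), (0, 10), (-1, 1))
#     # -> tensor([-1., 0., 1.]) (approximately)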


# @torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn += attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # B, H, x_seqlens, ref_seqlens

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # B, x_seqlens
        elif mode == 'max':
            # Tensor.max returns a (values, indices) pair; keep only the values so
            # both modes produce a tensor of shape B, x_seqlens.
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # B, x_seqlens

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count=0):
    """Args:
        visual_q (torch.tensor): B M H K
        ref_k (torch.tensor): B M H K
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks: [class_num, N_h * N_w]
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    if ref_images_count > 0:
        visual_q_shape = visual_q.shape
        visual_q = visual_q.reshape(visual_q_shape[0], N_t, -1)
        visual_q = visual_q[:, ref_images_count:]
        visual_q = visual_q.reshape(visual_q_shape[0], -1, *visual_q_shape[-2:])

    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_k[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_target_masks, ref_images_count)
        x_ref_attn_maps += x_ref_attn_maps_perhead

    x_ref_attn_maps /= split_num
    return x_ref_attn_maps
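
# Illustrative shapes (not from the original file): with B=1, 10 heads,
# head_dim=64, a latent grid of (N_t, N_h, N_w) = (2, 4, 4) and 2 target
# classes, the attention map is computed one head chunk at a time
# (split_num=10) and averaged over chunks:
#
#     visual_q = torch.randn(1, 2 * 4 * 4, 10, 64)
#     ref_k = torch.randn(1, 4 * 4, 10, 64)
#     masks = torch.zeros(2, 16)
#     masks[0, :8], masks[1, 8:] = 1.0, 1.0    # two spatial target regions
#     get_attn_map_with_target(visual_q, ref_k, (2, 4, 4), ref_target_masks=masks).shape
#     # -> torch.Size([2, 32]), i.e. [class_num, N_t * N_h * N_w]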


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):

        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            x (torch.tensor): [B, head, seq, head_dim]
            pos_indices (torch.tensor): [seq,]
        Returns:
            Tensor with the same shape as the input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)
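
# Illustrative usage (not from the original file): rotate a
# [B, heads, seq, head_dim] tensor by its 1D positions. precompute_freqs_cis_1d
# is lru_cached, so passing the same pos_indices tensor object reuses the
# cached frequencies.
#
#     rope = RotaryPositionalEmbedding1D(head_dim=64)
#     x = torch.randn(1, 8, 128, 64)
#     pos = torch.arange(128)
#     y = rope(x, pos)    # same shape as x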


def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):

    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # preprocess
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in tensor.unbind(2)
            ], dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            # write video
            writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"])
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            # keep the last error so the retry loop is actually able to retry
            error = e
            continue
    print(f'cache_video failed, error: {error}', flush=True)
    return None
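
# Illustrative call (not from the original file): the tensor layout is
# [B, C, T, H, W] with values in value_range; frames are unbound along dim 2.
#
#     cache_video(torch.rand(1, 3, 16, 256, 256) * 2 - 1, save_file='demo.mp4', fps=8)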


def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        writer = imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        )
        for frame in tqdm(frames, desc="Saving video"):
            frame = np.array(frame)
            writer.append_data(frame)
        writer.close()

    save_path_tmp = save_path + "-temp.mp4"

    if high_quality_save:
        cache_video(
            tensor=gen_video_samples.unsqueeze(0),
            save_file=save_path_tmp,
            fps=fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    else:
        video_audio = (gen_video_samples + 1) / 2  # C T H W
        video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
        video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)  # to [0, 255]
        save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

    # crop audio according to video length
    _, T, _, _ = gen_video_samples.shape
    duration = T / fps
    save_path_crop_audio = save_path + "-cropaudio.wav"
    final_command = [
        "ffmpeg",
        "-i",
        vocal_audio_list[0],
        "-t",
        f'{duration}',
        save_path_crop_audio,
    ]
    subprocess.run(final_command, check=True)

    save_path = save_path + ".mp4"
    if high_quality_save:
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-crf", "0",
            "-preset", "veryslow",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)
    else:
        final_command = [
            "ffmpeg",
            "-y",
            "-i", save_path_tmp,
            "-i", save_path_crop_audio,
            "-c:v", "libx264",
            "-c:a", "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(final_command, check=True)
        os.remove(save_path_tmp)
        os.remove(save_path_crop_audio)
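
# Illustrative call (not from the original file), where `video` is a C T H W
# tensor in [-1, 1]: the frames are muxed with the first audio track in
# vocal_audio_list and written to save_path + ".mp4"; the intermediate
# -temp.mp4 / -cropaudio.wav files are removed afterwards.
#
#     save_video_ffmpeg(video, "outputs/result", ["inputs/speech.wav"], fps=25)
#     # -> outputs/result.mp4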


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
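
# Illustrative usage (not from the original file): the buffer keeps the
# accumulation new = update + momentum * old. In APG-style guidance the
# momentum is typically negative (e.g. -0.5), which damps oscillations of the
# guidance direction across denoising steps:
#
#     buf = MomentumBuffer(momentum=-0.5)
#     buf.update(diff_step_0)
#     buf.update(diff_step_1)    # running_average = diff_step_1 - 0.5 * diff_step_0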


def project(
        v0: torch.Tensor,  # [B, C, T, H, W]
        v1: torch.Tensor,  # [B, C, T, H, W]
        ):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
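
# project() decomposes v0 into its component parallel to v1 plus the orthogonal
# remainder, so v0 == v0_parallel + v0_orthogonal. Illustrative check (not from
# the original file):
#
#     v0, v1 = torch.randn(1, 4, 8, 16, 16), torch.randn(1, 4, 8, 16, 16)
#     par, orth = project(v0, v1)
#     torch.allclose(par + orth, v0, atol=1e-5)    # True
#     (orth * v1).sum().abs() < 1e-3               # ~0: orthogonal to v1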


def adaptive_projected_guidance(
        diff: torch.Tensor,  # [B, C, T, H, W]
        pred_cond: torch.Tensor,  # [B, C, T, H, W]
        momentum_buffer: MomentumBuffer = None,
        eta: float = 0.0,
        norm_threshold: float = 55,
        ):
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        print(f"diff_norm: {diff_norm}")
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return normalized_update
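
# Illustrative CFG-style use (an assumption about the caller, not from this
# file): APG rescales and re-projects the guidance direction before it is added
# to the conditional prediction. pred_cond, pred_uncond and guide_scale are
# hypothetical caller-side names.
#
#     buf = MomentumBuffer(momentum=-0.5)
#     diff = pred_cond - pred_uncond
#     update = adaptive_projected_guidance(diff, pred_cond, momentum_buffer=buf)
#     pred = pred_cond + (guide_scale - 1) * update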