# Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
def after_patch_embedding(self, x: torch.Tensor, pose_latents, face_pixel_values):
    """Inject pose latents into the patch-embedded video tokens and encode
    per-frame face motion features.

    Args:
        x: patch-embedded video latents, indexed as a 5-D tensor
            ``(batch, channels, frames, height, width)``; modified in place
            (pose information is added to every frame except frame 0).
            NOTE(review): the original annotated this as ``List[torch.Tensor]``,
            but the body indexes it as a single tensor — annotation corrected.
        pose_latents: raw pose conditioning; after ``self.pose_patch_embedding``
            it must broadcast-add onto ``x[:, :, 1:]`` — assumed matching shape,
            TODO confirm against the caller.
        face_pixel_values: face crops, shape ``(b, c, t, h, w)``.

    Returns:
        Tuple of (``x`` with pose added, ``motion_vec``) where ``motion_vec``
        has shape ``(b, t + 1, heads, dim)`` — a zero "pad" frame is prepended
        along the time axis.
    """
    pose_latents = self.pose_patch_embedding(pose_latents)
    # Frame 0 is left untouched; pose conditioning applies from frame 1 on.
    x[:, :, 1:] += pose_latents

    b, c, T, h, w = face_pixel_values.shape
    # Flatten batch and time so the motion encoder sees single frames:
    # (b, c, t, h, w) -> (b*t, c, h, w). Equivalent to
    # rearrange(face_pixel_values, "b c t h w -> (b t) c h w").
    face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(b * T, c, h, w)

    # Encode in mini-batches to bound peak memory of the motion encoder.
    encode_bs = 8
    motion_chunks = [
        self.motion_encoder.get_motion(face_pixel_values[i:i + encode_bs])
        for i in range(0, face_pixel_values.shape[0], encode_bs)
    ]
    motion_vec = torch.cat(motion_chunks)

    # Restore the time axis: (b*t, c) -> (b, t, c).
    motion_vec = motion_vec.reshape(b, T, -1)
    motion_vec = self.face_encoder(motion_vec)

    B, L, H, C = motion_vec.shape
    # Zero motion features for the (untouched) first frame; allocate directly
    # with the right dtype/device instead of CPU-alloc + .type_as() copy.
    pad_face = torch.zeros(B, 1, H, C, dtype=motion_vec.dtype, device=motion_vec.device)
    motion_vec = torch.cat([pad_face, motion_vec], dim=1)
    return x, motion_vec