mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			21 lines
		
	
	
		
			600 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			21 lines
		
	
	
		
			600 B
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import torch.nn.functional as F
 | 
						|
 | 
						|
 | 
						|
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor with one length per batch item. It is cast to
            ``torch.long``, so fractional values are truncated.
        max_len: Width of the mask. Defaults to the largest entry of
            ``lengths`` (or 0 when ``lengths`` is empty). Lengths larger than
            ``max_len`` simply produce an all-True row.

    Returns:
        Bool tensor of shape ``(lengths.shape[0], max_len)`` on
        ``lengths.device`` where ``mask[i, j]`` is True iff ``j < lengths[i]``.
    """
    lengths = lengths.to(torch.long)
    if max_len is None:
        # torch.max raises on an empty tensor; an empty batch gets a (0, 0) mask.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0

    # Allocate the index row directly on the target device rather than
    # creating it on the CPU and copying it over on every call.
    ids = torch.arange(max_len, device=lengths.device)

    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len); no explicit
    # expand is needed.
    return ids.unsqueeze(0) < lengths.unsqueeze(1)
 | 
						|
 | 
						|
 | 
						|
def linear_interpolation(features, seq_len):
    """Linearly resample ``features`` along dim 1 to ``seq_len`` steps.

    ``F.interpolate`` in ``'linear'`` mode operates on the last dimension of a
    3-D input, so dims 1 and 2 are swapped before and after resampling. With
    ``align_corners=True`` the first and last input steps map exactly onto the
    first and last output steps.

    Args:
        features: 3-D tensor; presumably ``(batch, time, channels)`` — TODO
            confirm against callers.
        seq_len: Number of steps along dim 1 in the result.

    Returns:
        Tensor with dim 1 resized to ``seq_len`` and the other dims unchanged.
    """
    # Move the axis to resample into the last position, as 'linear' mode expects.
    channels_last_time = features.transpose(2, 1)
    resampled = F.interpolate(
        channels_last_time,
        size=seq_len,
        mode='linear',
        align_corners=True,
    )
    # Restore the original axis order.
    return resampled.transpose(2, 1)
 | 
						|
 |