mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
21 lines
600 B
Python
21 lines
600 B
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean validity mask from per-sequence lengths.

    Position ``j`` of row ``i`` is ``True`` exactly when ``j < lengths[i]``.

    Args:
        lengths: Tensor of sequence lengths; it is cast to ``torch.long``
            before use. Expected to be 1-D (one entry per sequence).
        max_len: Width of the mask. When ``None`` the largest value in
            ``lengths`` is used (``0`` if ``lengths`` is empty).

    Returns:
        A ``torch.bool`` tensor of shape ``(len(lengths), max_len)`` that
        lives on the same device as ``lengths``.
    """
    lengths = lengths.to(torch.long)
    if max_len is None:
        # torch.max raises on an empty tensor, so an empty batch falls back
        # to a zero-width mask instead of crashing.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0

    # Allocate the position indices directly on the target device rather
    # than creating them on the CPU and copying them over afterwards.
    ids = torch.arange(max_len, device=lengths.device)

    # Broadcasting (1, max_len) against (batch, 1) yields (batch, max_len).
    mask = ids.unsqueeze(0) < lengths.unsqueeze(1)
    return mask
|
|
|
|
|
|
def linear_interpolation(features, seq_len):
    """Linearly resample ``features`` along dimension 1 to ``seq_len`` steps.

    Dimensions 1 and 2 are swapped so that the axis to resample is last,
    ``F.interpolate`` is applied in ``'linear'`` mode with
    ``align_corners=True``, and the swap is then undone.

    Args:
        features: Input tensor. NOTE(review): presumably laid out as
            (batch, time, channels) -- confirm against callers.
        seq_len: Target size of dimension 1 in the output.

    Returns:
        The resampled tensor with dimension 1 of size ``seq_len``.
    """
    # F.interpolate resamples the trailing dimension, so move the
    # sequence axis there first.
    swapped = features.transpose(1, 2)
    resampled = F.interpolate(
        swapped,
        size=seq_len,
        mode='linear',
        align_corners=True,
    )
    # Restore the original axis order.
    return resampled.transpose(1, 2)
|
|
|