mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
from typing import Literal, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ..autoencoder.vae import VAE, get_my_vae
|
|
from ..bigvgan import BigVGAN
|
|
from ..bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
|
from ...model.utils.distributions import DiagonalGaussianDistribution
|
|
|
|
|
|
class AutoEncoderModule(nn.Module):
    """Frozen VAE + vocoder pair selected by ``mode`` ('16k' or '44k').

    ``encode``/``decode`` delegate to the VAE and ``vocode`` delegates to the
    vocoder. All parameters are frozen in ``__init__`` and every public method
    runs under ``torch.inference_mode``, so this module is inference-only.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path: str,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        """Build the VAE and vocoder for ``mode``, load weights, and freeze them.

        Args:
            vae_ckpt_path: Path to the VAE state dict; loaded on CPU with
                ``torch.load(..., weights_only=True)``.
            vocoder_ckpt_path: Checkpoint path for the BigVGAN vocoder. Required
                when ``mode == '16k'``; unused for ``'44k'``, where the vocoder is
                obtained via ``BigVGANv2.from_pretrained(
                'nvidia/bigvgan_v2_44khz_128band_512x')``.
            mode: Which VAE/vocoder configuration to build.
            need_vae_encoder: If False, ``self.vae.encoder`` is deleted after
                loading; ``encode`` will then presumably fail, since it delegates
                to ``self.vae.encode``.

        Raises:
            ValueError: If ``mode`` is not ``'16k'`` or ``'44k'``, or if
                ``mode == '16k'`` and ``vocoder_ckpt_path`` is None.
        """
        super().__init__()

        # Validate arguments before building any model or reading any checkpoint,
        # so a bad call fails fast. Explicit raises (instead of ``assert``) keep
        # the check active under ``python -O``.
        if mode not in ('16k', '44k'):
            raise ValueError(f'Unknown mode: {mode}')
        if mode == '16k' and vocoder_ckpt_path is None:
            raise ValueError("vocoder_ckpt_path is required when mode is '16k'")

        self.vae: VAE = get_my_vae(mode).eval()
        # map_location='cpu' loads the weights on CPU regardless of the device
        # they were saved from.
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict)
        self.vae.remove_weight_norm()

        if mode == '16k':
            # NOTE(review): unlike the 44k branch, no explicit remove_weight_norm()
            # is called here; presumably the BigVGAN wrapper handles it — confirm.
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        else:  # mode == '44k' (guaranteed by the validation above)
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False)
            self.vocoder.remove_weight_norm()

        # Freeze every parameter (VAE and vocoder); equivalent to setting
        # requires_grad = False on each entry of self.parameters().
        self.requires_grad_(False)

        if not need_vae_encoder:
            # Free the encoder when the caller only needs decode/vocode.
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Encode ``x`` with the VAE and return its latent distribution.

        Requires the module to have been built with ``need_vae_encoder=True``.
        """
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``z`` with the VAE."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``spec``.

        ``spec`` is presumably a spectrogram shaped as the selected vocoder
        expects. TODO: confirm against callers.
        """
        return self.vocoder(spec)
|