mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 07:44:53 +00:00
VAE Tiling Support - credits to deepbeepmeep's WanGP
This commit is contained in:
parent
a0de59e928
commit
09ba994635
@ -1,6 +1,6 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
import logging
|
import logging
|
||||||
|
from mmgp import offload
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -15,10 +15,6 @@ CACHE_T = 2
|
|||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Conv3d):
|
class CausalConv3d(nn.Conv3d):
|
||||||
"""
|
|
||||||
Causal 3d convolusion.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||||
@ -31,48 +27,38 @@ class CausalConv3d(nn.Conv3d):
|
|||||||
cache_x = cache_x.to(x.device)
|
cache_x = cache_x.to(x.device)
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
x = torch.cat([cache_x, x], dim=2)
|
||||||
padding[4] -= cache_x.shape[2]
|
padding[4] -= cache_x.shape[2]
|
||||||
|
cache_x = None
|
||||||
x = F.pad(x, padding)
|
x = F.pad(x, padding)
|
||||||
|
x = super().forward(x)
|
||||||
return super().forward(x)
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RMS_norm(nn.Module):
|
class RMS_norm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||||
|
|
||||||
self.channel_first = channel_first
|
self.channel_first = channel_first
|
||||||
self.scale = dim**0.5
|
self.scale = dim**0.5
|
||||||
self.gamma = nn.Parameter(torch.ones(shape))
|
self.gamma = nn.Parameter(torch.ones(shape))
|
||||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return F.normalize(
|
x = F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||||
x, dim=(1 if self.channel_first else
|
return x
|
||||||
-1)) * self.scale * self.gamma + self.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Upsample):
|
class Upsample(nn.Upsample):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
|
||||||
Fix bfloat16 support for nearest neighbor interpolation.
|
|
||||||
"""
|
|
||||||
return super().forward(x.float()).type_as(x)
|
return super().forward(x.float()).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
class Resample(nn.Module):
|
class Resample(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, mode):
|
def __init__(self, dim, mode):
|
||||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', 'downsample3d')
|
||||||
'downsample3d')
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
# layers
|
|
||||||
if mode == 'upsample2d':
|
if mode == 'upsample2d':
|
||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||||
@ -81,9 +67,7 @@ class Resample(nn.Module):
|
|||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||||
self.time_conv = CausalConv3d(
|
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
|
||||||
|
|
||||||
elif mode == 'downsample2d':
|
elif mode == 'downsample2d':
|
||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||||
@ -92,9 +76,7 @@ class Resample(nn.Module):
|
|||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||||
self.time_conv = CausalConv3d(
|
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.resample = nn.Identity()
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
@ -107,54 +89,38 @@ class Resample(nn.Module):
|
|||||||
feat_cache[idx] = 'Rep'
|
feat_cache[idx] = 'Rep'
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
clone = True
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||||
idx] is not None and feat_cache[idx] != 'Rep':
|
clone = False
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
clone = False
|
||||||
cache_x.device), cache_x
|
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
||||||
],
|
if clone:
|
||||||
dim=2)
|
cache_x = cache_x.clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] == 'Rep':
|
|
||||||
cache_x = torch.cat([
|
|
||||||
torch.zeros_like(cache_x).to(cache_x.device),
|
|
||||||
cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if feat_cache[idx] == 'Rep':
|
if feat_cache[idx] == 'Rep':
|
||||||
x = self.time_conv(x)
|
x = self.time_conv(x)
|
||||||
else:
|
else:
|
||||||
x = self.time_conv(x, feat_cache[idx])
|
x = self.time_conv(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
|
|
||||||
x = x.reshape(b, 2, c, t, h, w)
|
x = x.reshape(b, 2, c, t, h, w)
|
||||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
||||||
3)
|
|
||||||
x = x.reshape(b, c, t * 2, h, w)
|
x = x.reshape(b, c, t * 2, h, w)
|
||||||
t = x.shape[2]
|
t = x.shape[2]
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||||
x = self.resample(x)
|
x = self.resample(x)
|
||||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||||
|
|
||||||
if self.mode == 'downsample3d':
|
if self.mode == 'downsample3d':
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
if feat_cache[idx] is None:
|
if feat_cache[idx] is None:
|
||||||
feat_cache[idx] = x.clone()
|
feat_cache[idx] = x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -1:, :, :].clone()
|
cache_x = x[:, :, -1:, :, :].clone()
|
||||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
# # cache last frame of last two chunk
|
|
||||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
|
||||||
|
|
||||||
x = self.time_conv(
|
|
||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
return x
|
return x
|
||||||
@ -164,10 +130,8 @@ class Resample(nn.Module):
|
|||||||
nn.init.zeros_(conv_weight)
|
nn.init.zeros_(conv_weight)
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
c1, c2, t, h, w = conv_weight.size()
|
||||||
one_matrix = torch.eye(c1, c2)
|
one_matrix = torch.eye(c1, c2)
|
||||||
init_matrix = one_matrix
|
|
||||||
nn.init.zeros_(conv_weight)
|
nn.init.zeros_(conv_weight)
|
||||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
conv_weight.data[:, :, 1, 0, 0] = one_matrix
|
||||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
|
||||||
conv.weight.data.copy_(conv_weight)
|
conv.weight.data.copy_(conv_weight)
|
||||||
nn.init.zeros_(conv.bias.data)
|
nn.init.zeros_(conv.bias.data)
|
||||||
|
|
||||||
@ -176,7 +140,6 @@ class Resample(nn.Module):
|
|||||||
nn.init.zeros_(conv_weight)
|
nn.init.zeros_(conv_weight)
|
||||||
c1, c2, t, h, w = conv_weight.size()
|
c1, c2, t, h, w = conv_weight.size()
|
||||||
init_matrix = torch.eye(c1 // 2, c2)
|
init_matrix = torch.eye(c1 // 2, c2)
|
||||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
|
||||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||||
conv.weight.data.copy_(conv_weight)
|
conv.weight.data.copy_(conv_weight)
|
||||||
@ -184,20 +147,16 @@ class Resample(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_dim = in_dim
|
self.in_dim = in_dim
|
||||||
self.out_dim = out_dim
|
self.out_dim = out_dim
|
||||||
|
|
||||||
# layers
|
|
||||||
self.residual = nn.Sequential(
|
self.residual = nn.Sequential(
|
||||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
||||||
if in_dim != out_dim else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
h = self.shortcut(x)
|
h = self.shortcut(x)
|
||||||
@ -206,12 +165,7 @@ class ResidualBlock(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -221,20 +175,12 @@ class ResidualBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
"""
|
|
||||||
Causal self-attention with a single head.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
# layers
|
|
||||||
self.norm = RMS_norm(dim)
|
self.norm = RMS_norm(dim)
|
||||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||||
self.proj = nn.Conv2d(dim, dim, 1)
|
self.proj = nn.Conv2d(dim, dim, 1)
|
||||||
|
|
||||||
# zero out the last layer params
|
|
||||||
nn.init.zeros_(self.proj.weight)
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -242,36 +188,17 @@ class AttentionBlock(nn.Module):
|
|||||||
b, c, t, h, w = x.size()
|
b, c, t, h, w = x.size()
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
# compute query, key, value
|
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
||||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
-1).permute(0, 1, 3,
|
|
||||||
2).contiguous().chunk(
|
|
||||||
3, dim=-1)
|
|
||||||
|
|
||||||
# apply attention
|
|
||||||
x = F.scaled_dot_product_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
)
|
|
||||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||||
|
|
||||||
# output
|
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||||
return x + identity
|
return x + identity
|
||||||
|
|
||||||
|
|
||||||
class Encoder3d(nn.Module):
|
class Encoder3d(nn.Module):
|
||||||
|
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
|
||||||
def __init__(self,
|
temperal_downsample=[True, True, False], dropout=0.0):
|
||||||
dim=128,
|
|
||||||
z_dim=4,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_downsample=[True, True, False],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.z_dim = z_dim
|
self.z_dim = z_dim
|
||||||
@ -279,38 +206,24 @@ class Encoder3d(nn.Module):
|
|||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
self.attn_scales = attn_scales
|
self.attn_scales = attn_scales
|
||||||
self.temperal_downsample = temperal_downsample
|
self.temperal_downsample = temperal_downsample
|
||||||
|
|
||||||
# dimensions
|
|
||||||
dims = [dim * u for u in [1] + dim_mult]
|
dims = [dim * u for u in [1] + dim_mult]
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
# init block
|
|
||||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# downsample blocks
|
|
||||||
downsamples = []
|
downsamples = []
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
# residual (+attention) blocks
|
|
||||||
for _ in range(num_res_blocks):
|
for _ in range(num_res_blocks):
|
||||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||||
if scale in attn_scales:
|
if scale in attn_scales:
|
||||||
downsamples.append(AttentionBlock(out_dim))
|
downsamples.append(AttentionBlock(out_dim))
|
||||||
in_dim = out_dim
|
in_dim = out_dim
|
||||||
|
|
||||||
# downsample block
|
|
||||||
if i != len(dim_mult) - 1:
|
if i != len(dim_mult) - 1:
|
||||||
mode = 'downsample3d' if temperal_downsample[
|
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
|
||||||
i] else 'downsample2d'
|
|
||||||
downsamples.append(Resample(out_dim, mode=mode))
|
downsamples.append(Resample(out_dim, mode=mode))
|
||||||
scale /= 2.0
|
scale /= 2.0
|
||||||
self.downsamples = nn.Sequential(*downsamples)
|
self.downsamples = nn.Sequential(*downsamples)
|
||||||
|
|
||||||
# middle blocks
|
|
||||||
self.middle = nn.Sequential(
|
self.middle = nn.Sequential(
|
||||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||||
ResidualBlock(out_dim, out_dim, dropout))
|
ResidualBlock(out_dim, out_dim, dropout))
|
||||||
|
|
||||||
# output blocks
|
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||||
@ -320,46 +233,32 @@ class Encoder3d(nn.Module):
|
|||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
|
del cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
|
||||||
## downsamples
|
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## middle
|
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
|
||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
|
del cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
@ -367,15 +266,8 @@ class Encoder3d(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Decoder3d(nn.Module):
|
class Decoder3d(nn.Module):
|
||||||
|
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
|
||||||
def __init__(self,
|
temperal_upsample=[False, True, True], dropout=0.0):
|
||||||
dim=128,
|
|
||||||
z_dim=4,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_upsample=[False, True, True],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.z_dim = z_dim
|
self.z_dim = z_dim
|
||||||
@ -383,23 +275,14 @@ class Decoder3d(nn.Module):
|
|||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
self.attn_scales = attn_scales
|
self.attn_scales = attn_scales
|
||||||
self.temperal_upsample = temperal_upsample
|
self.temperal_upsample = temperal_upsample
|
||||||
|
|
||||||
# dimensions
|
|
||||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||||
|
|
||||||
# init block
|
|
||||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# middle blocks
|
|
||||||
self.middle = nn.Sequential(
|
self.middle = nn.Sequential(
|
||||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||||
ResidualBlock(dims[0], dims[0], dropout))
|
ResidualBlock(dims[0], dims[0], dropout))
|
||||||
|
|
||||||
# upsample blocks
|
|
||||||
upsamples = []
|
upsamples = []
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
# residual (+attention) blocks
|
|
||||||
if i == 1 or i == 2 or i == 3:
|
if i == 1 or i == 2 or i == 3:
|
||||||
in_dim = in_dim // 2
|
in_dim = in_dim // 2
|
||||||
for _ in range(num_res_blocks + 1):
|
for _ in range(num_res_blocks + 1):
|
||||||
@ -407,65 +290,46 @@ class Decoder3d(nn.Module):
|
|||||||
if scale in attn_scales:
|
if scale in attn_scales:
|
||||||
upsamples.append(AttentionBlock(out_dim))
|
upsamples.append(AttentionBlock(out_dim))
|
||||||
in_dim = out_dim
|
in_dim = out_dim
|
||||||
|
|
||||||
# upsample block
|
|
||||||
if i != len(dim_mult) - 1:
|
if i != len(dim_mult) - 1:
|
||||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||||
upsamples.append(Resample(out_dim, mode=mode))
|
upsamples.append(Resample(out_dim, mode=mode))
|
||||||
scale *= 2.0
|
scale *= 2.0
|
||||||
self.upsamples = nn.Sequential(*upsamples)
|
self.upsamples = nn.Sequential(*upsamples)
|
||||||
|
|
||||||
# output blocks
|
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
## conv1
|
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
|
del cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
|
||||||
## middle
|
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## upsamples
|
|
||||||
for layer in self.upsamples:
|
for layer in self.upsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
|
||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
# cache last frame of last two chunk
|
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
|
del cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
@ -481,15 +345,8 @@ def count_conv3d(model):
|
|||||||
|
|
||||||
|
|
||||||
class WanVAE_(nn.Module):
|
class WanVAE_(nn.Module):
|
||||||
|
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
|
||||||
def __init__(self,
|
temperal_downsample=[True, True, False], dropout=0.0):
|
||||||
dim=128,
|
|
||||||
z_dim=4,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_downsample=[True, True, False],
|
|
||||||
dropout=0.0):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.z_dim = z_dim
|
self.z_dim = z_dim
|
||||||
@ -498,14 +355,10 @@ class WanVAE_(nn.Module):
|
|||||||
self.attn_scales = attn_scales
|
self.attn_scales = attn_scales
|
||||||
self.temperal_downsample = temperal_downsample
|
self.temperal_downsample = temperal_downsample
|
||||||
self.temperal_upsample = temperal_downsample[::-1]
|
self.temperal_upsample = temperal_downsample[::-1]
|
||||||
|
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
|
||||||
# modules
|
|
||||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
|
||||||
attn_scales, self.temperal_downsample, dropout)
|
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
mu, log_var = self.encode(x)
|
mu, log_var = self.encode(x)
|
||||||
@ -513,40 +366,31 @@ class WanVAE_(nn.Module):
|
|||||||
x_recon = self.decode(z)
|
x_recon = self.decode(z)
|
||||||
return x_recon, mu, log_var
|
return x_recon, mu, log_var
|
||||||
|
|
||||||
def encode(self, x, scale):
|
def encode(self, x, scale=None):
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
## cache
|
|
||||||
t = x.shape[2]
|
t = x.shape[2]
|
||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (t - 1) // 4
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.encoder(
|
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||||
x[:, :, :1, :, :],
|
|
||||||
feat_cache=self._enc_feat_map,
|
|
||||||
feat_idx=self._enc_conv_idx)
|
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(
|
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
|
||||||
feat_cache=self._enc_feat_map,
|
|
||||||
feat_idx=self._enc_conv_idx)
|
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||||
|
if scale is not None:
|
||||||
if isinstance(scale[0], torch.Tensor):
|
if isinstance(scale[0], torch.Tensor):
|
||||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||||
1, self.z_dim, 1, 1, 1)
|
|
||||||
else:
|
else:
|
||||||
mu = (mu - scale[0]) * scale[1]
|
mu = (mu - scale[0]) * scale[1]
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
def decode(self, z, scale):
|
def decode(self, z, scale=None):
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
# z: [b,c,t,h,w]
|
if scale is not None:
|
||||||
if isinstance(scale[0], torch.Tensor):
|
if isinstance(scale[0], torch.Tensor):
|
||||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||||
1, self.z_dim, 1, 1, 1)
|
|
||||||
else:
|
else:
|
||||||
z = z / scale[1] + scale[0]
|
z = z / scale[1] + scale[0]
|
||||||
iter_ = z.shape[2]
|
iter_ = z.shape[2]
|
||||||
@ -554,19 +398,88 @@ class WanVAE_(nn.Module):
|
|||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.decoder(
|
out = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||||
x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(
|
out_ = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||||
x[:, :, i:i + 1, :, :],
|
|
||||||
feat_cache=self._feat_map,
|
|
||||||
feat_idx=self._conv_idx)
|
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||||
|
for y in range(blend_extent):
|
||||||
|
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||||
|
for x in range(blend_extent):
|
||||||
|
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def spatial_tiled_decode(self, z, scale, tile_size):
|
||||||
|
tile_sample_min_size = tile_size
|
||||||
|
tile_latent_min_size = int(tile_sample_min_size / 8)
|
||||||
|
tile_overlap_factor = 0.25
|
||||||
|
if isinstance(scale[0], torch.Tensor):
|
||||||
|
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
z = z / scale[1] + scale[0]
|
||||||
|
overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))
|
||||||
|
blend_extent = int(tile_sample_min_size * tile_overlap_factor)
|
||||||
|
row_limit = tile_sample_min_size - blend_extent
|
||||||
|
rows = []
|
||||||
|
for i in range(0, z.shape[-2], overlap_size):
|
||||||
|
row = []
|
||||||
|
for j in range(0, z.shape[-1], overlap_size):
|
||||||
|
tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
|
||||||
|
decoded = self.decode(tile)
|
||||||
|
row.append(decoded)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=-1))
|
||||||
|
return torch.cat(result_rows, dim=-2)
|
||||||
|
|
||||||
|
def spatial_tiled_encode(self, x, scale, tile_size):
|
||||||
|
tile_sample_min_size = tile_size
|
||||||
|
tile_latent_min_size = int(tile_sample_min_size / 8)
|
||||||
|
tile_overlap_factor = 0.25
|
||||||
|
overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
|
||||||
|
blend_extent = int(tile_latent_min_size * tile_overlap_factor)
|
||||||
|
row_limit = tile_latent_min_size - blend_extent
|
||||||
|
rows = []
|
||||||
|
for i in range(0, x.shape[-2], overlap_size):
|
||||||
|
row = []
|
||||||
|
for j in range(0, x.shape[-1], overlap_size):
|
||||||
|
tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
|
||||||
|
tile = self.encode(tile)
|
||||||
|
row.append(tile)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=-1))
|
||||||
|
mu = torch.cat(result_rows, dim=-2)
|
||||||
|
if isinstance(scale[0], torch.Tensor):
|
||||||
|
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
mu = (mu - scale[0]) * scale[1]
|
||||||
|
return mu
|
||||||
|
|
||||||
def reparameterize(self, mu, log_var):
|
def reparameterize(self, mu, log_var):
|
||||||
std = torch.exp(0.5 * log_var)
|
std = torch.exp(0.5 * log_var)
|
||||||
eps = torch.randn_like(std)
|
eps = torch.randn_like(std)
|
||||||
@ -583,81 +496,43 @@ class WanVAE_(nn.Module):
|
|||||||
self._conv_num = count_conv3d(self.decoder)
|
self._conv_num = count_conv3d(self.decoder)
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
self._feat_map = [None] * self._conv_num
|
self._feat_map = [None] * self._conv_num
|
||||||
#cache encode
|
|
||||||
self._enc_conv_num = count_conv3d(self.encoder)
|
self._enc_conv_num = count_conv3d(self.encoder)
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
self._enc_feat_map = [None] * self._enc_conv_num
|
self._enc_feat_map = [None] * self._enc_conv_num
|
||||||
|
|
||||||
|
|
||||||
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
||||||
"""
|
cfg = dict(dim=96, z_dim=z_dim, dim_mult=[1, 2, 4, 4], num_res_blocks=2,
|
||||||
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0)
|
||||||
"""
|
|
||||||
# params
|
|
||||||
cfg = dict(
|
|
||||||
dim=96,
|
|
||||||
z_dim=z_dim,
|
|
||||||
dim_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_scales=[],
|
|
||||||
temperal_downsample=[False, True, True],
|
|
||||||
dropout=0.0)
|
|
||||||
cfg.update(**kwargs)
|
cfg.update(**kwargs)
|
||||||
|
|
||||||
# init model
|
|
||||||
with torch.device('meta'):
|
|
||||||
model = WanVAE_(**cfg)
|
model = WanVAE_(**cfg)
|
||||||
|
|
||||||
# load checkpoint
|
|
||||||
logging.info(f'loading {pretrained_path}')
|
logging.info(f'loading {pretrained_path}')
|
||||||
model.load_state_dict(
|
model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
|
||||||
torch.load(pretrained_path, map_location=device), assign=True)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class WanVAE:
|
class WanVAE:
|
||||||
|
def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', dtype=torch.float, device="cuda"):
|
||||||
def __init__(self,
|
|
||||||
z_dim=16,
|
|
||||||
vae_pth='cache/vae_step_411000.pth',
|
|
||||||
dtype=torch.float,
|
|
||||||
device="cuda"):
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||||
mean = [
|
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
|
||||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
|
||||||
]
|
|
||||||
std = [
|
|
||||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
|
||||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
|
||||||
]
|
|
||||||
self.mean = torch.tensor(mean, dtype=dtype, device=device)
|
self.mean = torch.tensor(mean, dtype=dtype, device=device)
|
||||||
self.std = torch.tensor(std, dtype=dtype, device=device)
|
self.std = torch.tensor(std, dtype=dtype, device=device)
|
||||||
self.scale = [self.mean, 1.0 / self.std]
|
self.scale = [self.mean, 1.0 / self.std]
|
||||||
|
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, device=device)
|
||||||
|
self.model = self.model.eval().requires_grad_(False).to(device)
|
||||||
|
|
||||||
# init model
|
def encode(self, videos, tile_size=256):
|
||||||
self.model = _video_vae(
|
if tile_size > 0:
|
||||||
pretrained_path=vae_pth,
|
return [self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos]
|
||||||
z_dim=z_dim,
|
else:
|
||||||
).eval().requires_grad_(False).to(device)
|
return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
|
||||||
|
|
||||||
def encode(self, videos):
|
def decode(self, zs, tile_size):
|
||||||
"""
|
if tile_size > 0:
|
||||||
videos: A list of videos each with shape [C, T, H, W].
|
return [self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs]
|
||||||
"""
|
else:
|
||||||
with amp.autocast(dtype=self.dtype):
|
return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs]
|
||||||
return [
|
|
||||||
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
|
||||||
for u in videos
|
|
||||||
]
|
|
||||||
|
|
||||||
def decode(self, zs):
|
|
||||||
with amp.autocast(dtype=self.dtype):
|
|
||||||
return [
|
|
||||||
self.model.decode(u.unsqueeze(0),
|
|
||||||
self.scale).float().clamp_(-1, 1).squeeze(0)
|
|
||||||
for u in zs
|
|
||||||
]
|
|
||||||
|
Loading…
Reference in New Issue
Block a user