VAE Tiling Support - credits to deepbeepmeep's WanGP

pftq 2025-03-18 22:16:33 -07:00 committed by GitHub
parent a0de59e928
commit 09ba994635

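This change adds spatial tiling to the Wan VAE so large frames can be encoded and decoded as overlapping tiles rather than all at once, cutting peak VRAM at a modest compute cost. The tiling code (the blend_v/blend_h cross-fades and spatial_tiled_encode/spatial_tiled_decode, surfaced through a tile_size argument on WanVAE.encode/decode) is adapted from deepbeepmeep's WanGP. The diff also makes the scale argument of the inner encode/decode optional, imports offload from mmgp, drops the meta-device model init in _video_vae, releases cache tensors earlier (cache_x = None, del cache_x), and reflows long statements onto single lines.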

@@ -1,6 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import logging
-
+from mmgp import offload
 import torch
 import torch.cuda.amp as amp
 import torch.nn as nn
@@ -15,10 +15,6 @@ CACHE_T = 2

 class CausalConv3d(nn.Conv3d):
-    """
-    Causal 3d convolusion.
-    """
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._padding = (self.padding[2], self.padding[2], self.padding[1],
@@ -31,48 +27,38 @@ class CausalConv3d(nn.Conv3d):
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
             padding[4] -= cache_x.shape[2]
+            cache_x = None
         x = F.pad(x, padding)
-        return super().forward(x)
+        x = super().forward(x)
+        return x


 class RMS_norm(nn.Module):

     def __init__(self, dim, channel_first=True, images=True, bias=False):
         super().__init__()
         broadcastable_dims = (1, 1, 1) if not images else (1, 1)
         shape = (dim, *broadcastable_dims) if channel_first else (dim,)

         self.channel_first = channel_first
         self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(shape))
         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

     def forward(self, x):
-        return F.normalize(
-            x, dim=(1 if self.channel_first else
-                    -1)) * self.scale * self.gamma + self.bias
+        x = F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+        return x


 class Upsample(nn.Upsample):

     def forward(self, x):
-        """
-        Fix bfloat16 support for nearest neighbor interpolation.
-        """
         return super().forward(x.float()).type_as(x)


 class Resample(nn.Module):

     def __init__(self, dim, mode):
-        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
-                        'downsample3d')
+        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', 'downsample3d')
         super().__init__()
         self.dim = dim
         self.mode = mode

-        # layers
         if mode == 'upsample2d':
             self.resample = nn.Sequential(
                 Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
@@ -81,9 +67,7 @@ class Resample(nn.Module):
             self.resample = nn.Sequential(
                 Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                 nn.Conv2d(dim, dim // 2, 3, padding=1))
-            self.time_conv = CausalConv3d(
-                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
         elif mode == 'downsample2d':
             self.resample = nn.Sequential(
                 nn.ZeroPad2d((0, 1, 0, 1)),
@@ -92,9 +76,7 @@ class Resample(nn.Module):
             self.resample = nn.Sequential(
                 nn.ZeroPad2d((0, 1, 0, 1)),
                 nn.Conv2d(dim, dim, 3, stride=(2, 2)))
-            self.time_conv = CausalConv3d(
-                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
         else:
             self.resample = nn.Identity()
@@ -107,54 +89,38 @@ class Resample(nn.Module):
                     feat_cache[idx] = 'Rep'
                     feat_idx[0] += 1
                 else:
-                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                    if cache_x.shape[2] < 2 and feat_cache[
-                            idx] is not None and feat_cache[idx] != 'Rep':
-                        # cache last frame of last two chunk
-                        cache_x = torch.cat([
-                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                                cache_x.device), cache_x
-                        ],
-                                            dim=2)
-                    if cache_x.shape[2] < 2 and feat_cache[
-                            idx] is not None and feat_cache[idx] == 'Rep':
-                        cache_x = torch.cat([
-                            torch.zeros_like(cache_x).to(cache_x.device),
-                            cache_x
-                        ],
-                                            dim=2)
+                    clone = True
+                    cache_x = x[:, :, -CACHE_T:, :, :]
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
+                        clone = False
+                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
+                        clone = False
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if clone:
+                        cache_x = cache_x.clone()
                     if feat_cache[idx] == 'Rep':
                         x = self.time_conv(x)
                     else:
                         x = self.time_conv(x, feat_cache[idx])
                     feat_cache[idx] = cache_x
                     feat_idx[0] += 1

                     x = x.reshape(b, 2, c, t, h, w)
-                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
-                                    3)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                     x = x.reshape(b, c, t * 2, h, w)
         t = x.shape[2]
         x = rearrange(x, 'b c t h w -> (b t) c h w')
         x = self.resample(x)
         x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

         if self.mode == 'downsample3d':
             if feat_cache is not None:
                 idx = feat_idx[0]
                 if feat_cache[idx] is None:
-                    feat_cache[idx] = x.clone()
+                    feat_cache[idx] = x
                     feat_idx[0] += 1
                 else:
                     cache_x = x[:, :, -1:, :, :].clone()
-                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
-                    #     # cache last frame of last two chunk
-                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                    x = self.time_conv(
-                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))

                     feat_cache[idx] = cache_x
                     feat_idx[0] += 1
         return x
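The caching pattern this hunk streamlines can be checked in isolation: a causal temporal convolution only needs the last CACHE_T = 2 input frames of the previous chunk to continue a sequence exactly. A self-contained sketch (illustrative tensor sizes, not part of the commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(1, 1, (3, 1, 1))  # temporal kernel 3, no built-in padding
x = torch.randn(1, 1, 9, 4, 4)     # 9 frames, split as 1 + 4 + 4 below

# full-sequence reference: causal left-pad of 2 frames in time
ref = conv(F.pad(x, (0, 0, 0, 0, 2, 0)))

# chunked: process 1 then 4 then 4 frames, carrying a 2-frame cache
outs, cache = [], None
for chunk in x.split([1, 4, 4], dim=2):
    inp = F.pad(chunk, (0, 0, 0, 0, 2, 0)) if cache is None else torch.cat([cache, chunk], dim=2)
    outs.append(conv(inp))
    cache = inp[:, :, -2:]  # last CACHE_T frames feed the next chunk

assert torch.allclose(torch.cat(outs, dim=2), ref, atol=1e-6)
```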
@@ -164,10 +130,8 @@ class Resample(nn.Module):
         nn.init.zeros_(conv_weight)
         c1, c2, t, h, w = conv_weight.size()
         one_matrix = torch.eye(c1, c2)
-        init_matrix = one_matrix
         nn.init.zeros_(conv_weight)
-        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
-        conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
+        conv_weight.data[:, :, 1, 0, 0] = one_matrix
         conv.weight.data.copy_(conv_weight)
         nn.init.zeros_(conv.bias.data)
@@ -176,7 +140,6 @@ class Resample(nn.Module):
         nn.init.zeros_(conv_weight)
         c1, c2, t, h, w = conv_weight.size()
         init_matrix = torch.eye(c1 // 2, c2)
-        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
         conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
         conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
         conv.weight.data.copy_(conv_weight)
@@ -184,20 +147,16 @@ class Resample(nn.Module):

 class ResidualBlock(nn.Module):

     def __init__(self, in_dim, out_dim, dropout=0.0):
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim

-        # layers
         self.residual = nn.Sequential(
             RMS_norm(in_dim, images=False), nn.SiLU(),
             CausalConv3d(in_dim, out_dim, 3, padding=1),
             RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
             CausalConv3d(out_dim, out_dim, 3, padding=1))
-        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
-            if in_dim != out_dim else nn.Identity()
+        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         h = self.shortcut(x)
@@ -206,12 +165,7 @@ class ResidualBlock(nn.Module):
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
                 if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
+                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
                 x = layer(x, feat_cache[idx])
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
@@ -221,20 +175,12 @@ class ResidualBlock(nn.Module):

 class AttentionBlock(nn.Module):
-    """
-    Causal self-attention with a single head.
-    """

     def __init__(self, dim):
         super().__init__()
         self.dim = dim

-        # layers
         self.norm = RMS_norm(dim)
         self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
         self.proj = nn.Conv2d(dim, dim, 1)

-        # zero out the last layer params
         nn.init.zeros_(self.proj.weight)

     def forward(self, x):
@@ -242,36 +188,17 @@ class AttentionBlock(nn.Module):
         b, c, t, h, w = x.size()
         x = rearrange(x, 'b c t h w -> (b t) c h w')
         x = self.norm(x)
-        # compute query, key, value
-        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-                                         -1).permute(0, 1, 3,
-                                                     2).contiguous().chunk(
-                                                         3, dim=-1)
+        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)

-        # apply attention
-        x = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-        )
+        x = F.scaled_dot_product_attention(q, k, v)

         x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

-        # output
         x = self.proj(x)
-        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
+        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
         return x + identity
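The reflowed one-liners above are easiest to verify with concrete shapes: each frame's h*w spatial positions form the attention sequence, processed as a single head of width c. A minimal sketch with made-up dimensions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

b, c, t, h, w = 1, 8, 2, 4, 4
x = torch.randn(b * t, c, h, w)      # frames flattened into the batch
to_qkv = nn.Conv2d(c, c * 3, 1)

q, k, v = to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
assert q.shape == (b * t, 1, h * w, c)  # (batch, heads, seq_len, head_dim)

out = F.scaled_dot_product_attention(q, k, v)
out = out.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
assert out.shape == x.shape  # attention runs per frame over the h*w tokens
```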

 class Encoder3d(nn.Module):

-    def __init__(self,
-                 dim=128,
-                 z_dim=4,
-                 dim_mult=[1, 2, 4, 4],
-                 num_res_blocks=2,
-                 attn_scales=[],
-                 temperal_downsample=[True, True, False],
-                 dropout=0.0):
+    def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
+                 temperal_downsample=[True, True, False], dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -279,38 +206,24 @@ class Encoder3d(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.attn_scales = attn_scales
         self.temperal_downsample = temperal_downsample

-        # dimensions
         dims = [dim * u for u in [1] + dim_mult]
         scale = 1.0

-        # init block
         self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

-        # downsample blocks
         downsamples = []
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
             for _ in range(num_res_blocks):
                 downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                 if scale in attn_scales:
                     downsamples.append(AttentionBlock(out_dim))
                 in_dim = out_dim

-            # downsample block
             if i != len(dim_mult) - 1:
-                mode = 'downsample3d' if temperal_downsample[
-                    i] else 'downsample2d'
+                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                 downsamples.append(Resample(out_dim, mode=mode))
                 scale /= 2.0
         self.downsamples = nn.Sequential(*downsamples)

-        # middle blocks
         self.middle = nn.Sequential(
             ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
             ResidualBlock(out_dim, out_dim, dropout))

-        # output blocks
         self.head = nn.Sequential(
             RMS_norm(out_dim, images=False), nn.SiLU(),
             CausalConv3d(out_dim, z_dim, 3, padding=1))
@@ -320,46 +233,32 @@ class Encoder3d(nn.Module):
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([
-                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                        cache_x.device), cache_x
-                ],
-                                    dim=2)
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
             x = self.conv1(x, feat_cache[idx])
             feat_cache[idx] = cache_x
+            del cache_x
             feat_idx[0] += 1
         else:
             x = self.conv1(x)

-        ## downsamples
         for layer in self.downsamples:
             if feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)

-        ## middle
         for layer in self.middle:
             if isinstance(layer, ResidualBlock) and feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)

-        ## head
         for layer in self.head:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
                 if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
+                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
                 x = layer(x, feat_cache[idx])
                 feat_cache[idx] = cache_x
+                del cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)
@@ -367,15 +266,8 @@ class Encoder3d(nn.Module):

 class Decoder3d(nn.Module):

-    def __init__(self,
-                 dim=128,
-                 z_dim=4,
-                 dim_mult=[1, 2, 4, 4],
-                 num_res_blocks=2,
-                 attn_scales=[],
-                 temperal_upsample=[False, True, True],
-                 dropout=0.0):
+    def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
+                 temperal_upsample=[False, True, True], dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -383,23 +275,14 @@ class Decoder3d(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.attn_scales = attn_scales
         self.temperal_upsample = temperal_upsample

-        # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
         scale = 1.0 / 2**(len(dim_mult) - 2)

-        # init block
         self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

-        # middle blocks
         self.middle = nn.Sequential(
             ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
             ResidualBlock(dims[0], dims[0], dropout))

-        # upsample blocks
         upsamples = []
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
             if i == 1 or i == 2 or i == 3:
                 in_dim = in_dim // 2
             for _ in range(num_res_blocks + 1):
@@ -407,65 +290,46 @@ class Decoder3d(nn.Module):
                 if scale in attn_scales:
                     upsamples.append(AttentionBlock(out_dim))
                 in_dim = out_dim

-            # upsample block
             if i != len(dim_mult) - 1:
                 mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                 upsamples.append(Resample(out_dim, mode=mode))
                 scale *= 2.0
         self.upsamples = nn.Sequential(*upsamples)

-        # output blocks
         self.head = nn.Sequential(
             RMS_norm(out_dim, images=False), nn.SiLU(),
             CausalConv3d(out_dim, 3, 3, padding=1))

     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([
-                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                        cache_x.device), cache_x
-                ],
-                                    dim=2)
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
             x = self.conv1(x, feat_cache[idx])
             feat_cache[idx] = cache_x
+            del cache_x
             feat_idx[0] += 1
         else:
             x = self.conv1(x)

-        ## middle
         for layer in self.middle:
             if isinstance(layer, ResidualBlock) and feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)

-        ## upsamples
         for layer in self.upsamples:
             if feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)

-        ## head
         for layer in self.head:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
                 if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
+                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2)
                 x = layer(x, feat_cache[idx])
                 feat_cache[idx] = cache_x
+                del cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)
@@ -481,15 +345,8 @@ def count_conv3d(model):

 class WanVAE_(nn.Module):

-    def __init__(self,
-                 dim=128,
-                 z_dim=4,
-                 dim_mult=[1, 2, 4, 4],
-                 num_res_blocks=2,
-                 attn_scales=[],
-                 temperal_downsample=[True, True, False],
-                 dropout=0.0):
+    def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[],
+                 temperal_downsample=[True, True, False], dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -498,14 +355,10 @@ class WanVAE_(nn.Module):
         self.attn_scales = attn_scales
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]

-        # modules
-        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
-                                 attn_scales, self.temperal_downsample, dropout)
+        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
-        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
-                                 attn_scales, self.temperal_upsample, dropout)
+        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
def forward(self, x): def forward(self, x):
mu, log_var = self.encode(x) mu, log_var = self.encode(x)
@ -513,60 +366,120 @@ class WanVAE_(nn.Module):
x_recon = self.decode(z) x_recon = self.decode(z)
return x_recon, mu, log_var return x_recon, mu, log_var
def encode(self, x, scale): def encode(self, x, scale=None):
self.clear_cache() self.clear_cache()
## cache
t = x.shape[2] t = x.shape[2]
iter_ = 1 + (t - 1) // 4 iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_): for i in range(iter_):
self._enc_conv_idx = [0] self._enc_conv_idx = [0]
if i == 0: if i == 0:
out = self.encoder( out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else: else:
out_ = self.encoder( out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1) mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor): if scale is not None:
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( if isinstance(scale[0], torch.Tensor):
1, self.z_dim, 1, 1, 1) mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else: else:
mu = (mu - scale[0]) * scale[1] mu = (mu - scale[0]) * scale[1]
self.clear_cache() self.clear_cache()
return mu return mu
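The `1 + (t - 1) // 4` loop reflects the VAE's 4x temporal compression: the first frame is encoded on its own, and every following group of four frames produces one more latent frame. A quick worked example (frame count chosen for illustration):

```python
t = 17                    # input video frames
iter_ = 1 + (t - 1) // 4  # = 5 encoder calls
chunks = [1] + [4] * (iter_ - 1)
assert sum(chunks) == t   # frames consumed as 1, 4, 4, 4, 4
# each call appends one latent time step, so the latent ends up with 5 frames
```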
-    def decode(self, z, scale):
+    def decode(self, z, scale=None):
         self.clear_cache()
-        # z: [b,c,t,h,w]
-        if isinstance(scale[0], torch.Tensor):
-            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
-                1, self.z_dim, 1, 1, 1)
-        else:
-            z = z / scale[1] + scale[0]
+        if scale is not None:
+            if isinstance(scale[0], torch.Tensor):
+                z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+            else:
+                z = z / scale[1] + scale[0]
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(
-                    x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                out = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
             else:
-                out_ = self.decoder(
-                    x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                out_ = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
         self.clear_cache()
         return out
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+        return b
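The two new helpers are plain linear cross-fades over the overlap band: the first row of the incoming tile is taken entirely from the previous tile, and the weight shifts linearly toward the incoming tile across blend_extent rows. A hand check with a small extent:

```python
blend_extent = 4
weights = [(1 - y / blend_extent, y / blend_extent) for y in range(blend_extent)]
# (previous-tile weight, incoming-tile weight) per overlap row:
assert weights == [(1.0, 0.0), (0.75, 0.25), (0.5, 0.5), (0.25, 0.75)]
```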
+    def spatial_tiled_decode(self, z, scale, tile_size):
+        tile_sample_min_size = tile_size
+        tile_latent_min_size = int(tile_sample_min_size / 8)
+        tile_overlap_factor = 0.25
+
+        if isinstance(scale[0], torch.Tensor):
+            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+        else:
+            z = z / scale[1] + scale[0]
+
+        overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))
+        blend_extent = int(tile_sample_min_size * tile_overlap_factor)
+        row_limit = tile_sample_min_size - blend_extent
+
+        rows = []
+        for i in range(0, z.shape[-2], overlap_size):
+            row = []
+            for j in range(0, z.shape[-1], overlap_size):
+                tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
+                decoded = self.decode(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        return torch.cat(result_rows, dim=-2)
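For the default tile_size = 256 the geometry works out as below; the factor of 8 is the VAE's spatial compression, so a 32x32 latent tile decodes to 256x256 pixels (numbers shown for illustration):

```python
tile_sample_min_size = 256        # pixels per decoded tile edge
tile_latent_min_size = 256 // 8   # = 32 latent positions per tile edge
overlap_size = int(32 * 0.75)     # = 24: stride between tiles in latent space
blend_extent = int(256 * 0.25)    # = 64: pixel rows cross-faded between tiles
row_limit = 256 - 64              # = 192: pixels each tile contributes
assert row_limit == overlap_size * 8  # kept pixels line up with the latent stride
```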
+    def spatial_tiled_encode(self, x, scale, tile_size):
+        tile_sample_min_size = tile_size
+        tile_latent_min_size = int(tile_sample_min_size / 8)
+        tile_overlap_factor = 0.25
+
+        overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
+        blend_extent = int(tile_latent_min_size * tile_overlap_factor)
+        row_limit = tile_latent_min_size - blend_extent
+
+        rows = []
+        for i in range(0, x.shape[-2], overlap_size):
+            row = []
+            for j in range(0, x.shape[-1], overlap_size):
+                tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
+                tile = self.encode(tile)
+                row.append(tile)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        mu = torch.cat(result_rows, dim=-2)
+        if isinstance(scale[0], torch.Tensor):
+            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+        else:
+            mu = (mu - scale[0]) * scale[1]
+        return mu
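The encode path mirrors the decode geometry with the two spaces swapped: tiles are strided in pixel space while blending and cropping happen on the latents (again with tile_size = 256):

```python
overlap_size = int(256 * 0.75)        # = 192: stride between tiles in pixel space
blend_extent = int(256 // 8 * 0.25)   # = 8: latent rows cross-faded
row_limit = 256 // 8 - 8              # = 24: latent rows each tile contributes
assert overlap_size == row_limit * 8  # pixel stride matches latents kept
```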
     def reparameterize(self, mu, log_var):
         std = torch.exp(0.5 * log_var)
         eps = torch.randn_like(std)
@@ -583,81 +496,43 @@ class WanVAE_(nn.Module):
         self._conv_num = count_conv3d(self.decoder)
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
-        #cache encode
         self._enc_conv_num = count_conv3d(self.encoder)
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

 def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
-    """
-    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
-    """
-    # params
-    cfg = dict(
-        dim=96,
-        z_dim=z_dim,
-        dim_mult=[1, 2, 4, 4],
-        num_res_blocks=2,
-        attn_scales=[],
-        temperal_downsample=[False, True, True],
-        dropout=0.0)
+    cfg = dict(dim=96, z_dim=z_dim, dim_mult=[1, 2, 4, 4], num_res_blocks=2,
+               attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0)
     cfg.update(**kwargs)
+    model = WanVAE_(**cfg)

-    # init model
-    with torch.device('meta'):
-        model = WanVAE_(**cfg)
-
-    # load checkpoint
     logging.info(f'loading {pretrained_path}')
-    model.load_state_dict(
-        torch.load(pretrained_path, map_location=device), assign=True)
+    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)

     return model

 class WanVAE:

-    def __init__(self,
-                 z_dim=16,
-                 vae_pth='cache/vae_step_411000.pth',
-                 dtype=torch.float,
-                 device="cuda"):
+    def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', dtype=torch.float, device="cuda"):
         self.dtype = dtype
         self.device = device

-        mean = [
-            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
-            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
-        ]
-        std = [
-            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
-            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
-        ]
+        mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+                0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
+        std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+               3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
         self.mean = torch.tensor(mean, dtype=dtype, device=device)
         self.std = torch.tensor(std, dtype=dtype, device=device)
         self.scale = [self.mean, 1.0 / self.std]

-        # init model
-        self.model = _video_vae(
-            pretrained_path=vae_pth,
-            z_dim=z_dim,
-        ).eval().requires_grad_(False).to(device)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, device=device)
+        self.model = self.model.eval().requires_grad_(False).to(device)

-    def encode(self, videos):
-        """
-        videos: A list of videos each with shape [C, T, H, W].
-        """
-        with amp.autocast(dtype=self.dtype):
-            return [
-                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
-                for u in videos
-            ]
+    def encode(self, videos, tile_size=256):
+        if tile_size > 0:
+            return [self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos]
+        else:
+            return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]

-    def decode(self, zs):
-        with amp.autocast(dtype=self.dtype):
-            return [
-                self.model.decode(u.unsqueeze(0),
-                                  self.scale).float().clamp_(-1, 1).squeeze(0)
-                for u in zs
-            ]
+    def decode(self, zs, tile_size):
+        if tile_size > 0:
+            return [self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs]
+        else:
+            return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs]
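Taken together, callers now opt into tiling per call via tile_size (0 falls back to the untiled path). A hypothetical usage sketch; the tensor shape and frame count are illustrative, not from the commit:

```python
import torch

vae = WanVAE(vae_pth='cache/vae_step_411000.pth', device='cuda')
video = torch.randn(3, 17, 480, 832, device='cuda')  # [C, T, H, W], values in [-1, 1]

latents = vae.encode([video], tile_size=256)  # tiled encode, ~256 px tiles
frames = vae.decode(latents, tile_size=256)   # tiled decode, clamped to [-1, 1]
full = vae.decode(latents, tile_size=0)       # untiled fallback
```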