From 09ba994635350c746bdd81094212793cd05003b3 Mon Sep 17 00:00:00 2001 From: pftq Date: Tue, 18 Mar 2025 22:16:33 -0700 Subject: [PATCH] VAE Tiling Support - credits to deepbeepmeep's WanGP --- wan/modules/vae.py | 439 ++++++++++++++++----------------------------- 1 file changed, 157 insertions(+), 282 deletions(-) diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 5c6da57..c70a066 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -1,6 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging - +from mmgp import offload import torch import torch.cuda.amp as amp import torch.nn as nn @@ -15,10 +15,6 @@ CACHE_T = 2 class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._padding = (self.padding[2], self.padding[2], self.padding[1], @@ -31,48 +27,38 @@ class CausalConv3d(nn.Conv3d): cache_x = cache_x.to(x.device) x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] + cache_x = None x = F.pad(x, padding) - - return super().forward(x) + x = super().forward(x) + return x class RMS_norm(nn.Module): - def __init__(self, dim, channel_first=True, images=True, bias=False): super().__init__() broadcastable_dims = (1, 1, 1) if not images else (1, 1) shape = (dim, *broadcastable_dims) if channel_first else (dim,) - self.channel_first = channel_first self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(shape)) self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias + x = F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + return x class Upsample(nn.Upsample): - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. 
- """ return super().forward(x.float()).type_as(x) class Resample(nn.Module): - def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', 'downsample3d') super().__init__() self.dim = dim self.mode = mode - - # layers if mode == 'upsample2d': self.resample = nn.Sequential( Upsample(scale_factor=(2., 2.), mode='nearest-exact'), @@ -81,9 +67,7 @@ class Resample(nn.Module): self.resample = nn.Sequential( Upsample(scale_factor=(2., 2.), mode='nearest-exact'), nn.Conv2d(dim, dim // 2, 3, padding=1)) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == 'downsample2d': self.resample = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), @@ -92,9 +76,7 @@ class Resample(nn.Module): self.resample = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: self.resample = nn.Identity() @@ -107,54 +89,38 @@ class Resample(nn.Module): feat_cache[idx] = 'Rep' feat_idx[0] += 1 else: - - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + clone = True + cache_x = x[:, :, -CACHE_T:, :, :] + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep': + clone = False + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep': + clone = False + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if clone: + cache_x = cache_x.clone() if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: x = self.time_conv(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) x = x.reshape(b, c, t * 2, h, w) t = x.shape[2] x = rearrange(x, 'b c t h w -> (b t) c h w') x = self.resample(x) x = rearrange(x, '(b t) c h w -> b c t h w', t=t) - if self.mode == 'downsample3d': if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() + feat_cache[idx] = x feat_idx[0] += 1 else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - - x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 return x @@ -164,10 +130,8 @@ class Resample(nn.Module): nn.init.zeros_(conv_weight) c1, c2, t, h, w = conv_weight.size() 
one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix nn.init.zeros_(conv_weight) - #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv_weight.data[:, :, 1, 0, 0] = one_matrix conv.weight.data.copy_(conv_weight) nn.init.zeros_(conv.bias.data) @@ -176,7 +140,6 @@ class Resample(nn.Module): nn.init.zeros_(conv_weight) c1, c2, t, h, w = conv_weight.size() init_matrix = torch.eye(c1 // 2, c2) - #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix conv.weight.data.copy_(conv_weight) @@ -184,20 +147,16 @@ class Resample(nn.Module): class ResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout=0.0): super().__init__() self.in_dim = in_dim self.out_dim = out_dim - - # layers self.residual = nn.Sequential( RMS_norm(in_dim, images=False), nn.SiLU(), CausalConv3d(in_dim, out_dim, 3, padding=1), RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), CausalConv3d(out_dim, out_dim, 3, padding=1)) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): h = self.shortcut(x) @@ -206,12 +165,7 @@ class ResidualBlock(nn.Module): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -221,20 +175,12 @@ class ResidualBlock(nn.Module): class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. 
- """ - def __init__(self, dim): super().__init__() self.dim = dim - - # layers self.norm = RMS_norm(dim) self.to_qkv = nn.Conv2d(dim, dim * 3, 1) self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params nn.init.zeros_(self.proj.weight) def forward(self, x): @@ -242,36 +188,17 @@ class AttentionBlock(nn.Module): b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') x = self.norm(x) - # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) - - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + x = F.scaled_dot_product_attention(q, k, v) x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output x = self.proj(x) - x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) return x + identity class Encoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], + temperal_downsample=[True, True, False], dropout=0.0): super().__init__() self.dim = dim self.z_dim = z_dim @@ -279,38 +206,24 @@ class Encoder3d(nn.Module): self.num_res_blocks = num_res_blocks self.attn_scales = attn_scales self.temperal_downsample = temperal_downsample - - # dimensions dims = [dim * u for u in [1] + dim_mult] scale = 1.0 - - # init block self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) - - # downsample blocks downsamples = [] for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks for _ in range(num_res_blocks): downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) if scale in attn_scales: downsamples.append(AttentionBlock(out_dim)) in_dim = out_dim - - # downsample block if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' + mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d' downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 self.downsamples = nn.Sequential(*downsamples) - - # middle blocks self.middle = nn.Sequential( ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)) - - # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) @@ -320,46 +233,32 @@ class Encoder3d(nn.Module): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x + del cache_x feat_idx[0] += 1 else: x = self.conv1(x) - - ## downsamples for layer in self.downsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - - ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - - ## head for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = 
feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x + del cache_x feat_idx[0] += 1 else: x = layer(x) @@ -367,15 +266,8 @@ class Encoder3d(nn.Module): class Decoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], + temperal_upsample=[False, True, True], dropout=0.0): super().__init__() self.dim = dim self.z_dim = z_dim @@ -383,23 +275,14 @@ class Decoder3d(nn.Module): self.num_res_blocks = num_res_blocks self.attn_scales = attn_scales self.temperal_upsample = temperal_upsample - - # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] scale = 1.0 / 2**(len(dim_mult) - 2) - - # init block self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks self.middle = nn.Sequential( ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)) - - # upsample blocks upsamples = [] for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks if i == 1 or i == 2 or i == 3: in_dim = in_dim // 2 for _ in range(num_res_blocks + 1): @@ -407,65 +290,46 @@ class Decoder3d(nn.Module): if scale in attn_scales: upsamples.append(AttentionBlock(out_dim)) in_dim = out_dim - - # upsample block if i != len(dim_mult) - 1: mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 self.upsamples = nn.Sequential(*upsamples) - - # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): - ## conv1 if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x + del cache_x feat_idx[0] += 1 else: x = self.conv1(x) - - ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - - ## upsamples for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - - ## head for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x + del cache_x feat_idx[0] += 1 else: x = 
layer(x) @@ -481,15 +345,8 @@ def count_conv3d(model): class WanVAE_(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], + temperal_downsample=[True, True, False], dropout=0.0): super().__init__() self.dim = dim self.z_dim = z_dim @@ -498,14 +355,10 @@ class WanVAE_(nn.Module): self.attn_scales = attn_scales self.temperal_downsample = temperal_downsample self.temperal_upsample = temperal_downsample[::-1] - - # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def forward(self, x): mu, log_var = self.encode(x) @@ -513,60 +366,120 @@ class WanVAE_(nn.Module): x_recon = self.decode(z) return x_recon, mu, log_var - def encode(self, x, scale): + def encode(self, x, scale=None): self.clear_cache() - ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 - ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: - out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] + if scale is not None: + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] self.clear_cache() return mu - def decode(self, z, scale): + def decode(self, z, scale=None): self.clear_cache() - # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] + if scale is not None: + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: - out_ = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out_ = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) self.clear_cache() 
return out + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def spatial_tiled_decode(self, z, scale, tile_size): + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 8) + tile_overlap_factor = 0.25 + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) + blend_extent = int(tile_sample_min_size * tile_overlap_factor) + row_limit = tile_sample_min_size - blend_extent + rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size] + decoded = self.decode(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + + def spatial_tiled_encode(self, x, scale, tile_size): + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 8) + tile_overlap_factor = 0.25 + overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor)) + blend_extent = int(tile_latent_min_size * tile_overlap_factor) + row_limit = tile_latent_min_size - blend_extent + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size] + tile = self.encode(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + mu = torch.cat(result_rows, dim=-2) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + return mu + def reparameterize(self, mu, log_var): std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) @@ -583,81 +496,43 @@ class WanVAE_(nn.Module): self._conv_num = count_conv3d(self.decoder) self._conv_idx = [0] self._feat_map = [None] * self._conv_num - #cache encode self._enc_conv_num = count_conv3d(self.encoder) self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x 
and XL. - """ - # params - cfg = dict( - dim=96, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[False, True, True], - dropout=0.0) + cfg = dict(dim=96, z_dim=z_dim, dim_mult=[1, 2, 4, 4], num_res_blocks=2, + attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0) cfg.update(**kwargs) - - # init model - with torch.device('meta'): - model = WanVAE_(**cfg) - - # load checkpoint + model = WanVAE_(**cfg) logging.info(f'loading {pretrained_path}') - model.load_state_dict( - torch.load(pretrained_path, map_location=device), assign=True) - + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) return model class WanVAE: - - def __init__(self, - z_dim=16, - vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cuda"): + def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', dtype=torch.float, device="cuda"): self.dtype = dtype self.device = device - - mean = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 - ] - std = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 - ] + mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921] + std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160] self.mean = torch.tensor(mean, dtype=dtype, device=device) self.std = torch.tensor(std, dtype=dtype, device=device) self.scale = [self.mean, 1.0 / self.std] + self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, device=device) + self.model = self.model.eval().requires_grad_(False).to(device) - # init model - self.model = _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - ).eval().requires_grad_(False).to(device) + def encode(self, videos, tile_size=256): + if tile_size > 0: + return [self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos] + else: + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] - def encode(self, videos): - """ - videos: A list of videos each with shape [C, T, H, W]. - """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] - - def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + def decode(self, zs, tile_size): + if tile_size > 0: + return [self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs] + else: + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs]
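
Usage note (not part of the diff): a minimal sketch of the tiled interface this patch adds to WanVAE. The checkpoint path is just the constructor's default, and the 480x832 resolution and 17-frame clip below are placeholders; inputs are assumed to be normalized to [-1, 1] with a 4k+1 frame count to match the encoder's 1+4+4+... temporal chunking, and the `mmgp` package imported at the top of the patched module must be importable for `wan.modules.vae` to load.

import torch
from wan.modules.vae import WanVAE

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = WanVAE(z_dim=16, vae_pth="cache/vae_step_411000.pth", device=device)

# One video per list entry, shape [C, T, H, W], values assumed in [-1, 1].
video = torch.rand(3, 17, 480, 832, device=device) * 2 - 1

# tile_size is measured in pixel space: each spatial tile maps to a
# (tile_size // 8) latent window, with a 25% overlap blended by
# blend_h/blend_v; passing tile_size=0 falls back to the untiled path.
latents = vae.encode([video], tile_size=256)   # -> list of [16, 5, 60, 104]
frames = vae.decode(latents, tile_size=256)    # -> list of [3, 17, 480, 832], clamped to [-1, 1]
print(latents[0].shape, frames[0].shape)

With tile_size=256 the spatial tiles overlap by a factor of 0.25 (overlap_size and blend_extent in spatial_tiled_encode/spatial_tiled_decode), so seams are linearly blended before the rows are concatenated back together; the temporal chunking and feature cache behave exactly as in the untiled encode/decode.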