added VAE tiling for ti2v 5B

This commit is contained in:
deepbeepmeep 2025-08-08 03:59:20 +02:00
parent 6b17c9fb6a
commit e13206b583
5 changed files with 63 additions and 58 deletions


@@ -20,15 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
### August 6 2025: WanGP v7.7 - Picky, picky
### August 6 2025: WanGP v7.71 - Picky, picky
This release comes with two new models:
- Qwen Image: a commercial-grade image generator capable of injecting full sentences into the generated image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model you need to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B")
There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likiwise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
Please note that the VAE decoding of Wan 2.2 TextImage is not tiled yet and may produce VRAM consumption peaks (this doesn't mix well with the 720p requirement).
*7.71 update: added VAE tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, to keep VRAM usage low throughout the whole generation.*
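For readers unfamiliar with the trick, here is a minimal, self-contained sketch of spatially tiled VAE decoding; it is not WanGP's actual implementation, and `tiled_decode`, `decode_fn` and the toy `fake_decode` upsampler are made up for illustration. The latent is cut into overlapping tiles, each tile is decoded on its own so only one tile's activations occupy VRAM at a time, and the overlapping borders are cross-faded to hide seams:

```python
import torch
import torch.nn.functional as F

def tiled_decode(z, decode_fn, tile=32, overlap=8, scale=8):
    """Decode a latent (B, C, T, H, W) in overlapping spatial tiles.

    Only one tile is decoded at a time, so peak VRAM is bounded by the tile
    size rather than the full frame; overlapping borders are cross-faded.
    """
    B, _, T, H, W = z.shape
    step = tile - overlap
    out = torch.zeros(B, 3, T, H * scale, W * scale)
    weight = torch.zeros_like(out)
    ramp = torch.linspace(0, 1, overlap * scale)
    for i in range(0, H, step):
        for j in range(0, W, step):
            xt = decode_fn(z[:, :, :, i:i + tile, j:j + tile])  # (B, 3, T, h, w) pixels
            mask = torch.ones_like(xt)
            if i > 0:                 # feather the top edge of non-first rows
                mask[:, :, :, :overlap * scale, :] *= ramp.view(1, 1, 1, -1, 1)
            if i + tile < H:          # feather the bottom edge of non-last rows
                mask[:, :, :, -overlap * scale:, :] *= ramp.flip(0).view(1, 1, 1, -1, 1)
            if j > 0:                 # feather the left edge of non-first columns
                mask[:, :, :, :, :overlap * scale] *= ramp.view(1, 1, 1, 1, -1)
            if j + tile < W:          # feather the right edge of non-last columns
                mask[:, :, :, :, -overlap * scale:] *= ramp.flip(0).view(1, 1, 1, 1, -1)
            hs, ws = i * scale, j * scale
            out[:, :, :, hs:hs + xt.shape[-2], ws:ws + xt.shape[-1]] += xt * mask
            weight[:, :, :, hs:hs + xt.shape[-2], ws:ws + xt.shape[-1]] += mask
    return out / weight.clamp_min(1e-6)

# toy stand-in for the real VAE decoder: 4-channel latent -> 3-channel, 8x upsampled
fake_decode = lambda zt: F.interpolate(zt[:, :3], scale_factor=(1, 8, 8), mode="nearest")
print(tiled_decode(torch.randn(1, 4, 5, 64, 64), fake_decode).shape)  # (1, 3, 5, 512, 512)
```

Trading a little extra compute on the overlaps for a bounded decode footprint is what makes the 720p requirement of the 5B model workable on low-VRAM cards.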
### August 4 2025: WanGP v7.6 - Remuxed


@@ -656,7 +656,7 @@ class WanAny2V:
height, width = (height // 32) * 32, (width // 32) * 32
else:
height, width = input_video.shape[-2:]
source_latents = self.vae.encode([input_video])[0].unsqueeze(0)
source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
timestep_injection = True
# Vace


@@ -520,6 +520,8 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
_offload_hooks = ['encode', 'decode']
def __init__(self,
dim=128,
z_dim=4,


@@ -798,62 +798,83 @@ class WanVAE_(nn.Module):
x_recon = self.decode(mu, scale)
return x_recon, mu
def encode(self, x, scale, any_end_frame = False):
def encode(self, x, scale = None, any_end_frame = False):
self.clear_cache()
x = patchify(x, patch_size=2)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
if any_end_frame:
iter_ = 2 + (t - 2) // 4
else:
iter_ = 1 + (t - 1) // 4
## split the encode input x along the time axis into chunks of 1, 4, 4, 4, ...
out_list = []
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
out_list.append(self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
feat_idx=self._enc_conv_idx))
elif any_end_frame and i== iter_ -1:
out_list.append(self.encoder(
x[:, :, -1:, :, :],
feat_cache= None,
feat_idx=self._enc_conv_idx))
else:
out_ = self.encoder(
out_list.append(self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
feat_idx=self._enc_conv_idx))
self.clear_cache()
out = torch.cat(out_list, 2)
out_list = None
mu, log_var = self.conv1(out).chunk(2, dim=1)
if scale != None:
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
return mu
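The loop above never feeds the encoder the whole clip at once: it walks the time axis in causal chunks of 1, 4, 4, 4, ... frames and carries a feature cache between chunks, and with `any_end_frame` the last frame gets its own pass without the cache. A small stand-alone sketch of that slicing arithmetic (the `temporal_chunks` helper is hypothetical, added only for illustration):

```python
def temporal_chunks(t, any_end_frame=False):
    """Frame slices the causal encoder walks through: one frame first,
    then groups of four; optionally the final frame as its own chunk.
    Matches iter_ = 1 + (t - 1) // 4 (or 2 + (t - 2) // 4) in the loop above."""
    body = t - 1 if any_end_frame else t
    chunks = [slice(0, 1)]
    chunks += [slice(1 + 4 * i, 1 + 4 * (i + 1)) for i in range((body - 1) // 4)]
    if any_end_frame:
        chunks.append(slice(t - 1, t))
    return chunks

# 21 frames -> frames 0, 1-4, 5-8, 9-12, 13-16, 17-20 (six encoder passes)
print(temporal_chunks(21))
# with any_end_frame, the last frame is encoded separately, without the feature cache
print(temporal_chunks(22, any_end_frame=True))
```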
def decode(self, z, scale,any_end_frame = False):
def decode(self, z, scale=None, any_end_frame = False):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
# z: [b,c,t,h,w]
if scale != None:
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
out_list = []
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
out_list.append(self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
first_chunk = True)
)
elif any_end_frame and i==iter_-1:
out_list.append(self.decoder(
x[:, :, -1:, :, :],
feat_cache=None ,
feat_idx=self._conv_idx))
else:
out_ = self.decoder(
out_list.append(self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
feat_idx=self._conv_idx))
self.clear_cache()
out = torch.cat(out_list, 2)
out = unpatchify(out, patch_size=2)
return out
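A note on why `scale` became optional in `encode` and `decode` above: the latent normalisation is a per-channel affine map, so it can be applied once to the whole latent rather than inside every call, which is presumably why the tiled loops below invoke `encode`/`decode` without it. A minimal sketch of the round trip, with made-up mean and inverse-std values:

```python
import torch

z_dim = 4
mean = torch.randn(z_dim)            # per-channel latent mean (illustrative values)
inv_std = torch.rand(z_dim) + 0.5    # per-channel inverse std (illustrative values)

def normalize(mu):                   # encode side: (mu - mean) * inv_std
    return (mu - mean.view(1, z_dim, 1, 1, 1)) * inv_std.view(1, z_dim, 1, 1, 1)

def denormalize(z):                  # decode side: z / inv_std + mean
    return z / inv_std.view(1, z_dim, 1, 1, 1) + mean.view(1, z_dim, 1, 1, 1)

mu = torch.randn(1, z_dim, 3, 8, 8)  # (B, C, T, H, W) latent
assert torch.allclose(denormalize(normalize(mu)), mu, atol=1e-5)

# per-channel, so it commutes with spatial tiling: normalizing the whole latent
# and then cutting a tile equals normalizing the tile directly
assert torch.allclose(normalize(mu)[..., :4, :4], normalize(mu[..., :4, :4]))
```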
@@ -894,7 +915,7 @@ class WanVAE_(nn.Module):
row = []
for j in range(0, z.shape[-1], overlap_size):
tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
decoded = self.decode(tile, scale, any_end_frame= any_end_frame)
decoded = self.decode(tile, any_end_frame= any_end_frame)
row.append(decoded)
rows.append(row)
result_rows = []
@@ -928,7 +949,7 @@ class WanVAE_(nn.Module):
row = []
for j in range(0, x.shape[-1], overlap_size):
tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
tile = self.encode(tile, scale, any_end_frame= any_end_frame)
tile = self.encode(tile, any_end_frame= any_end_frame)
row.append(tile)
rows.append(row)
result_rows = []
@@ -1175,34 +1196,16 @@ class Wan2_2_VAE:
"""
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0 and False:
if tile_size > 0 and False :
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
else:
return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
def decode(self, zs, tile_size, any_end_frame = False):
def decode(self, zs, tile_size = 256, any_end_frame = False):
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0 and False:
if tile_size > 0 :
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
else:
return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
# def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):
# with amp.autocast(dtype=self.dtype):
# return [
# self.model.encode(u.unsqueeze(0),
# self.scale).float().squeeze(0)
# for u in videos
# ]
# def decode(self, zs, VAE_tile_size = 0, any_end_frame = False):
# with amp.autocast(dtype=self.dtype):
# return [
# self.model.decode(u.unsqueeze(0),
# self.scale).float().clamp_(-1,
# 1).squeeze(0)
# for u in zs
# ]
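For context on the wrapper above: decoding now defaults to a 256-pixel tile and takes the tiled path whenever `tile_size > 0`, while encoding keeps its `and False` guard and still runs untiled. A hedged sketch of that dispatch, assuming `decode` and `spatial_tiled_decode` keep the signatures shown in the diff:

```python
import torch

def decode_latents(model, zs, scale, dtype=torch.bfloat16, tile_size=256, any_end_frame=False):
    """Decode a list of latents, spatially tiled when tile_size > 0 to cap peak VRAM.

    Mirrors the dispatch in the Wan2_2_VAE wrapper above; the decode /
    spatial_tiled_decode signatures are assumed from the diff, not verified.
    """
    out = []
    for z in zs:
        z = z.to(dtype).unsqueeze(0)
        if tile_size > 0:
            x = model.spatial_tiled_decode(z, scale, tile_size, any_end_frame=any_end_frame)
        else:
            x = model.decode(z, scale, any_end_frame=any_end_frame)
        out.append(x.clamp_(-1, 1).float().squeeze(0))
    return out
```

Keeping the untiled branch around means a caller can still pass `tile_size=0` and decode the whole frame at once when memory is not a concern.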

wgp.py

@@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.6"
WanGP_version = "7.7"
WanGP_version = "7.71"
settings_version = 2.23
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None