Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit e13206b583: added VAE tiling for ti2v 5B
Parent commit: 6b17c9fb6a
@@ -20,15 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

 ## 🔥 Latest Updates :

-### August 6 2025: WanGP v7.7 - Picky, picky
+### August 6 2025: WanGP v7.71 - Picky, picky

 This release comes with two new models :
 - Qwen Image: a commercial-grade image generator capable of injecting full sentences into the generated image while still offering incredible visuals
 - Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B")

-There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likiwise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
+There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.

-Please note that the VAE decoding of Wan 2.2 TextImage is not tiled yet and it may produce VRAM consumption peaks (this doesn't mix well with the 720p requirement).
+*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during the whole generation.*

 ### August 4 2025: WanGP v7.6 - Remuxed
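For context on why tiling addresses the VRAM peaks the old note mentioned: the VAE decoder's activation memory grows with the spatial area it processes at once, so decoding fixed-size tiles caps the peak regardless of output resolution. A minimal, self-contained sketch of that scaling argument (the channel count, layer count and dtype size below are illustrative assumptions, not measurements of the Wan 2.2 VAE):

def approx_activation_mib(height, width, channels=128, layers=8, dtype_bytes=2):
    # crude model: 'layers' feature maps of shape [channels, height, width]
    # resident at once, each stored with 'dtype_bytes' bytes per element
    return layers * channels * height * width * dtype_bytes / 2**20

full_frame = approx_activation_mib(720, 1280)   # untiled 720p decode
one_tile = approx_activation_mib(256, 256)      # one 256x256 tile
print(f"~{full_frame:.0f} MiB for a full 720p frame vs ~{one_tile:.0f} MiB per tile")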
@@ -656,7 +656,7 @@ class WanAny2V:
                 height, width = (height // 32) * 32, (width // 32) * 32
             else:
                 height, width = input_video.shape[-2:]
-            source_latents = self.vae.encode([input_video])[0].unsqueeze(0)
+            source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
            timestep_injection = True

        # Vace
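The change above threads a caller-chosen tile size into the source-video encode. How VAE_tile_size is picked is outside this hunk; a hypothetical helper (not WanGP's actual logic) might derive it from free VRAM, with 0 meaning "no tiling" as the tile_size > 0 checks later in the diff suggest:

import torch

def pick_vae_tile_size(default=256, comfortable_free_gib=24):
    # Hypothetical policy: skip tiling when plenty of VRAM is free,
    # otherwise fall back to a fixed default tile size.
    if not torch.cuda.is_available():
        return 0
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return 0 if free_bytes >= comfortable_free_gib * 2**30 else default

VAE_tile_size = pick_vae_tile_size()
# source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)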
@@ -520,6 +520,8 @@ def count_conv3d(model):

 class WanVAE_(nn.Module):

+    _offload_hooks = ['encode', 'decode']
+
     def __init__(self,
                  dim=128,
                  z_dim=4,
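The new _offload_hooks attribute reads like a declaration of which methods an offloading wrapper should intercept (the wgp.py hunk below pins a mmgp version, but the actual wiring is not shown here, so this is an assumption). A generic sketch of that pattern, using a hypothetical install_offload_hooks helper rather than any real mmgp API:

import functools

def install_offload_hooks(module, before, after):
    # Wrap every method named in module._offload_hooks so that 'before' runs
    # ahead of the call (e.g. load weights onto the GPU) and 'after' runs once
    # it returns (e.g. move them back to CPU).
    for name in getattr(module, "_offload_hooks", []):
        original = getattr(module, name)

        def make_wrapper(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                before(module)
                try:
                    return fn(*args, **kwargs)
                finally:
                    after(module)
            return wrapper

        setattr(module, name, make_wrapper(original))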
@@ -798,62 +798,83 @@ class WanVAE_(nn.Module):
         x_recon = self.decode(mu, scale)
         return x_recon, mu

-    def encode(self, x, scale, any_end_frame = False):
+    def encode(self, x, scale = None, any_end_frame = False):
         self.clear_cache()
         x = patchify(x, patch_size=2)
         ## cache
         t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        if any_end_frame:
+            iter_ = 2 + (t - 2) // 4
+        else:
+            iter_ = 1 + (t - 1) // 4
         ## split the encoder input x along time into chunks of 1, 4, 4, 4, ...
+        out_list = []
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
-                out = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, :1, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
-                )
+                    feat_idx=self._enc_conv_idx))
+            elif any_end_frame and i== iter_ -1:
+                out_list.append(self.encoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache= None,
+                    feat_idx=self._enc_conv_idx))
             else:
-                out_ = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
-                )
-                out = torch.cat([out, out_], 2)
-        mu, log_var = self.conv1(out).chunk(2, dim=1)
-        if isinstance(scale[0], torch.Tensor):
-            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
-                1, self.z_dim, 1, 1, 1)
-        else:
-            mu = (mu - scale[0]) * scale[1]
+                    feat_idx=self._enc_conv_idx))
+
+        self.clear_cache()
+        out = torch.cat(out_list, 2)
+        out_list = None
+
+        mu, log_var = self.conv1(out).chunk(2, dim=1)
+        if scale != None:
+            if isinstance(scale[0], torch.Tensor):
+                mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+                    1, self.z_dim, 1, 1, 1)
+            else:
+                mu = (mu - scale[0]) * scale[1]
         return mu

-    def decode(self, z, scale, any_end_frame = False):
+    def decode(self, z, scale=None, any_end_frame = False):
         self.clear_cache()
-        if isinstance(scale[0], torch.Tensor):
-            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
-                1, self.z_dim, 1, 1, 1)
-        else:
-            z = z / scale[1] + scale[0]
         # z: [b,c,t,h,w]
+        if scale != None:
+            if isinstance(scale[0], torch.Tensor):
+                z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+                    1, self.z_dim, 1, 1, 1)
+            else:
+                z = z / scale[1] + scale[0]
         iter_ = z.shape[2]
         x = self.conv2(z)
+        out_list = []
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
                     feat_idx=self._conv_idx,
-                    first_chunk=True,
-                )
+                    first_chunk = True)
+                )
+            elif any_end_frame and i==iter_-1:
+                out_list.append(self.decoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache=None ,
+                    feat_idx=self._conv_idx))
             else:
-                out_ = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
-                )
-                out = torch.cat([out, out_], 2)
-        out = unpatchify(out, patch_size=2)
+                    feat_idx=self._conv_idx))
+        self.clear_cache()
+        out = torch.cat(out_list, 2)
+        out = unpatchify(out, patch_size=2)
         return out
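Two things change in the methods above: scale becomes optional, so the tiled paths in the next hunks can encode and decode raw tiles and leave normalization to the caller, and per-chunk outputs are gathered in a list and concatenated once instead of growing out with repeated torch.cat. A tiny self-contained illustration of that second point (toy tensors, not the VAE):

import torch

chunks = [torch.randn(1, 4, 1, 8, 8) for _ in range(8)]   # stand-ins for per-chunk outputs

# old pattern: incremental concatenation, reallocating a growing tensor each step
out = chunks[0]
for c in chunks[1:]:
    out = torch.cat([out, c], 2)

# new pattern: collect first, concatenate once at the end
out_list = list(chunks)
out_once = torch.cat(out_list, 2)

assert torch.equal(out, out_once)   # same result, fewer large temporaries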
@@ -894,7 +915,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, z.shape[-1], overlap_size):
                 tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
-                decoded = self.decode(tile, scale, any_end_frame= any_end_frame)
+                decoded = self.decode(tile, any_end_frame= any_end_frame)
                 row.append(decoded)
             rows.append(row)
         result_rows = []
@@ -928,7 +949,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, x.shape[-1], overlap_size):
                 tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
-                tile = self.encode(tile, scale, any_end_frame= any_end_frame)
+                tile = self.encode(tile, any_end_frame= any_end_frame)
                 row.append(tile)
             rows.append(row)
         result_rows = []
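Both tiled paths above now call the per-chunk encode/decode without scale, operating on raw tiles. For readers unfamiliar with the technique, here is a simplified, self-contained sketch of spatially tiled decoding with overlapping tiles and a uniform average in the overlaps; the real spatial_tiled_decode has its own tile sizes and blending, so treat this purely as an illustration, with decode_fn standing in for the per-tile VAE decode:

import torch

def tiled_decode_2d(z, decode_fn, tile=32, overlap=8, scale_factor=16):
    # Decode latent z of shape (..., H, W) tile by tile; overlapping regions
    # are averaged, so peak memory stays roughly one decoded tile.
    stride = tile - overlap
    h, w = z.shape[-2:]
    out = weight = None
    for i in range(0, h, stride):
        for j in range(0, w, stride):
            dec = decode_fn(z[..., i:i + tile, j:j + tile])
            if out is None:
                out = torch.zeros(*dec.shape[:-2], h * scale_factor, w * scale_factor)
                weight = torch.zeros_like(out)
            th, tw = dec.shape[-2:]
            y0, x0 = i * scale_factor, j * scale_factor
            out[..., y0:y0 + th, x0:x0 + tw] += dec
            weight[..., y0:y0 + th, x0:x0 + tw] += 1
    return out / weight.clamp_min(1)

# toy "decoder" that just upsamples 16x, only to exercise the tiling loop
z = torch.randn(1, 4, 64, 48)
frames = tiled_decode_2d(z, lambda t: t.repeat_interleave(16, -2).repeat_interleave(16, -1))
print(frames.shape)  # torch.Size([1, 4, 1024, 768])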
@@ -1175,34 +1196,16 @@ class Wan2_2_VAE:
         """
         scale = [u.to(device = self.device) for u in self.scale]

-        if tile_size > 0 and False:
+        if tile_size > 0 and False :
             return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
             return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


-    def decode(self, zs, tile_size, any_end_frame = False):
+    def decode(self, zs, tile_size = 256, any_end_frame = False):
         scale = [u.to(device = self.device) for u in self.scale]

-        if tile_size > 0 and False:
+        if tile_size > 0 :
             return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
             return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]

-
-    # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):
-    #     with amp.autocast(dtype=self.dtype):
-    #         return [
-    #             self.model.encode(u.unsqueeze(0),
-    #                               self.scale).float().squeeze(0)
-    #             for u in videos
-    #         ]
-
-    # def decode(self, zs, VAE_tile_size = 0, any_end_frame = False):
-    #     with amp.autocast(dtype=self.dtype):
-    #         return [
-    #             self.model.decode(u.unsqueeze(0),
-    #                               self.scale).float().clamp_(-1,
-    #                                                          1).squeeze(0)
-    #             for u in zs
-    #         ]
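Net effect of the wrapper changes: decode now defaults to tile_size = 256 and routes through spatial_tiled_decode whenever tile_size > 0, while the tiled encode branch stays disabled by the "and False" guard. A toy stand-in (not the real Wan2_2_VAE; the encode signature here is an assumption since it is not shown in the hunk) that mirrors just this dispatch:

class TinyVAEWrapper:
    # Mirrors only the dispatch logic of the hunk above, with strings in place
    # of real tensors.
    def decode(self, zs, tile_size=256, any_end_frame=False):
        if tile_size > 0:
            return [f"spatial_tiled_decode(tile_size={tile_size})" for _ in zs]
        return ["decode" for _ in zs]

    def encode(self, videos, tile_size=0, any_end_frame=False):
        if tile_size > 0 and False:   # tiled encode intentionally left off
            return ["spatial_tiled_encode" for _ in videos]
        return ["encode" for _ in videos]

vae = TinyVAEWrapper()
print(vae.decode([None]))                 # ['spatial_tiled_decode(tile_size=256)']
print(vae.decode([None], tile_size=0))    # ['decode']
print(vae.encode([None], tile_size=256))  # ['encode'], the guard keeps tiling off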
wgp.py
@@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10

 target_mmgp_version = "3.5.6"
-WanGP_version = "7.7"
+WanGP_version = "7.71"
 settings_version = 2.23
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None