diff --git a/README.md b/README.md
index ceeb7a4..77750ff 100644
--- a/README.md
+++ b/README.md
@@ -20,15 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 
 ## 🔥 Latest Updates :
-### August 6 2025: WanGP v7.7 - Picky, picky
+### August 6 2025: WanGP v7.71 - Picky, picky
 
 This release comes with two new models :
 - Qwen Image: a Commercial grade Image generator capable to inject full sentences in the generated Image while still offering incredible visuals
 - Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" )
 
-There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likiwise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
+There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
 
-Please note that the VAE decoding of Wan 2.2 TextImage is not tiled yet and it may produce VRAM consumption peaks (this doens't mix well with the 720p requirement).
+*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, so VRAM stays low during the whole generation.*
 
 
 ### August 4 2025: WanGP v7.6 - Remuxed
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index b409279..4f8d7f1 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -656,7 +656,7 @@ class WanAny2V:
                 height, width = (height // 32) * 32, (width // 32) * 32
             else:
                 height, width = input_video.shape[-2:]
-            source_latents = self.vae.encode([input_video])[0].unsqueeze(0)
+            source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
             timestep_injection = True
 
         # Vace
diff --git a/models/wan/modules/vae.py b/models/wan/modules/vae.py
index feb0d2e..8290d94 100644
--- a/models/wan/modules/vae.py
+++ b/models/wan/modules/vae.py
@@ -520,6 +520,8 @@ def count_conv3d(model):
 
 
 class WanVAE_(nn.Module):
+    _offload_hooks = ['encode', 'decode']
+
     def __init__(self,
                  dim=128,
                  z_dim=4,
diff --git a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py
index 220979b..c1a88f5 100644
--- a/models/wan/modules/vae2_2.py
+++ b/models/wan/modules/vae2_2.py
@@ -798,62 +798,83 @@ class WanVAE_(nn.Module):
         x_recon = self.decode(mu, scale)
         return x_recon, mu
 
-    def encode(self, x, scale, any_end_frame = False):
+    def encode(self, x, scale = None, any_end_frame = False):
         self.clear_cache()
-        x = patchify(x, patch_size=2)
+        x = patchify(x, patch_size=2)
+        ## cache
         t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        if any_end_frame:
+            iter_ = 2 + (t - 2) // 4
+        else:
+            iter_ = 1 + (t - 1) // 4
+        ## split the x passed to encode along the time axis into chunks of 1, 4, 4, 4, ... frames
+        out_list = []
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
-                out = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, :1, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
-                )
+                    feat_idx=self._enc_conv_idx))
+            elif any_end_frame and i== iter_ -1:
+                out_list.append(self.encoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache= None,
+                    feat_idx=self._enc_conv_idx))
             else:
-                out_ = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
-                )
-                out = torch.cat([out, out_], 2)
-        mu, log_var = self.conv1(out).chunk(2, dim=1)
-        if isinstance(scale[0], torch.Tensor):
-            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
-                1, self.z_dim, 1, 1, 1)
-        else:
-            mu = (mu - scale[0]) * scale[1]
+                    feat_idx=self._enc_conv_idx))
+        self.clear_cache()
+        out = torch.cat(out_list, 2)
+        out_list = None
+
+        mu, log_var = self.conv1(out).chunk(2, dim=1)
+        if scale != None:
+            if isinstance(scale[0], torch.Tensor):
+                mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+                    1, self.z_dim, 1, 1, 1)
+            else:
+                mu = (mu - scale[0]) * scale[1]
         return mu
 
-    def decode(self, z, scale,any_end_frame = False):
+
+    def decode(self, z, scale=None, any_end_frame = False):
         self.clear_cache()
-        if isinstance(scale[0], torch.Tensor):
-            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
-                1, self.z_dim, 1, 1, 1)
-        else:
-            z = z / scale[1] + scale[0]
+        # z: [b,c,t,h,w]
+        if scale != None:
+            if isinstance(scale[0], torch.Tensor):
+                z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+                    1, self.z_dim, 1, 1, 1)
+            else:
+                z = z / scale[1] + scale[0]
         iter_ = z.shape[2]
         x = self.conv2(z)
+        out_list = []
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
                     feat_idx=self._conv_idx,
-                    first_chunk=True,
-                )
+                    first_chunk = True)
+                )
+            elif any_end_frame and i==iter_-1:
+                out_list.append(self.decoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache=None ,
+                    feat_idx=self._conv_idx))
             else:
-                out_ = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
-                )
-                out = torch.cat([out, out_], 2)
-        out = unpatchify(out, patch_size=2)
+                    feat_idx=self._conv_idx))
         self.clear_cache()
+        out = torch.cat(out_list, 2)
+        out = unpatchify(out, patch_size=2)
         return out
 
 
@@ -894,7 +915,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, z.shape[-1], overlap_size):
                 tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
-                decoded = self.decode(tile, scale, any_end_frame= any_end_frame)
+                decoded = self.decode(tile, any_end_frame= any_end_frame)
                 row.append(decoded)
             rows.append(row)
         result_rows = []
@@ -928,7 +949,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, x.shape[-1], overlap_size):
                 tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
-                tile = self.encode(tile, scale, any_end_frame= any_end_frame)
+                tile = self.encode(tile, any_end_frame= any_end_frame)
                 row.append(tile)
             rows.append(row)
         result_rows = []
@@ -1175,34 +1196,16 @@ class Wan2_2_VAE:
         """
 
         scale = [u.to(device = self.device) for u in self.scale]
-        if tile_size > 0 and False:
+        if tile_size > 0 and False :
             return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
             return [ self.model.encode(u.to(self.dtype).unsqueeze(0),
                 scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
 
 
-    def decode(self, zs, tile_size, any_end_frame = False):
+    def decode(self, zs, tile_size = 256, any_end_frame = False):
         scale = [u.to(device = self.device) for u in self.scale]
-
-        if tile_size > 0 and False:
+        if tile_size > 0 :
             return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
             return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
-
-    # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):
-    #     with amp.autocast(dtype=self.dtype):
-    #         return [
-    #             self.model.encode(u.unsqueeze(0),
-    #                               self.scale).float().squeeze(0)
-    #             for u in videos
-    #         ]
-
-    # def decode(self, zs, VAE_tile_size = 0, any_end_frame = False):
-    #     with amp.autocast(dtype=self.dtype):
-    #         return [
-    #             self.model.decode(u.unsqueeze(0),
-    #                               self.scale).float().clamp_(-1,
-    #                                                          1).squeeze(0)
-    #             for u in zs
-    #         ]
diff --git a/wgp.py b/wgp.py
index c932141..5c8f9f1 100644
--- a/wgp.py
+++ b/wgp.py
@@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
 target_mmgp_version = "3.5.6"
-WanGP_version = "7.7"
+WanGP_version = "7.71"
 settings_version = 2.23
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
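For reference, here is a minimal usage sketch of the VAE entry points as they look after this patch. The `vae`, `input_video`, and `VAE_tile_size` names are assumed stand-ins for objects created elsewhere in WanGP (e.g. in `wgp.py` and `any2video.py`); only the `encode`/`decode` signatures and the `tile_size` routing are taken from the diff above.

```python
# Illustrative sketch only, not part of the patch.
# Assumes `vae` is a constructed Wan2_2_VAE instance (construction omitted; see
# models/wan/modules/vae2_2.py) and `input_video` is a [C, T, H, W] tensor in [-1, 1].
VAE_tile_size = 256  # hypothetical value; WanGP picks this based on available VRAM

# Encoding: any2video.py now forwards the tile size to the VAE wrapper. Note that the
# spatially tiled *encode* branch is still disabled in this patch
# (`if tile_size > 0 and False`), so encoding falls through to the regular
# per-chunk temporal path.
source_latents = vae.encode([input_video], tile_size=VAE_tile_size)[0].unsqueeze(0)

# Decoding: tile_size > 0 routes through model.spatial_tiled_decode, which decodes the
# latents in overlapping spatial tiles and stitches the rows back together, so peak
# VRAM tracks the tile size instead of the full frame.
frames_tiled = vae.decode([source_latents.squeeze(0)], tile_size=VAE_tile_size)

# tile_size == 0 keeps the previous one-shot decode for GPUs with plenty of VRAM.
frames_full = vae.decode([source_latents.squeeze(0)], tile_size=0)
```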