From 697cc2cce5801482685816b8bd32db5a97698b6c Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 4 Mar 2025 02:39:44 +0100 Subject: [PATCH] Implemented VAE tiling --- README.md | 2 +- gradio_server.py | 36 ++++++++--- wan/image2video.py | 7 +- wan/modules/vae.py | 157 +++++++++++++++++++++++++++++++++++---------- wan/text2video.py | 5 +- 5 files changed, 160 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index c4d2f61..b912ed6 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! +* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implementented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache) * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings: - Support for all Wan including the Image to Video model @@ -126,7 +127,6 @@ To run the application while loading entirely the diffusion model in VRAM (sligh ```bash python gradio_server.py --profile 3 ``` -Please note that diffusion model of Wan2.1GP is extremely VRAM optimized and this will greatly benefit low VRAM systems since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of a image 2 video process) and the VAE decoder (at the end of any video process) is still VRAM hungry after optimization and it will require temporarly 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore if you have less than these numbers, you may experience slow downs at the beginning and at the end of the generation process due to pytorch VRAM offloading. ### Loras support diff --git a/gradio_server.py b/gradio_server.py index 4c34ee0..bf0996b 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -433,7 +433,7 @@ def load_models(i2v, lora_dir, lora_preselected_preset ): kwargs = { "extraModelsToQuantize": None} if profile == 2 or profile == 4: - kwargs["budgets"] = { "transformer" : 100, "*" : 3000 } + kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 } loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None) offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs) @@ -693,11 +693,29 @@ def generate_video( if "1.3B" in transformer_filename_t2v and width * height > 848*480: raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P") - offload.shared_state["_vae"] = vae_config - offload.shared_state["_vae_threshold"] = 0.9* torch.cuda.get_device_properties(0).total_memory offload.shared_state["_attention"] = attn + # VAE Tiling + device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 + if vae_config == 0: + if device_mem_capacity >= 24000: + use_vae_config = 1 + elif device_mem_capacity >= 8000: + use_vae_config = 2 + else: + use_vae_config = 3 + else: + use_vae_config = vae_config + + if use_vae_config == 1: + VAE_tile_size = 0 + elif use_vae_config == 2: + VAE_tile_size = 256 + else: + VAE_tile_size = 128 + + global gen_in_progress gen_in_progress = True temp_filename = None @@ -818,7 +836,8 @@ def generate_video( seed=seed, offload_model=False, callback=callback, - enable_RIFLEx = enable_RIFLEx + enable_RIFLEx = enable_RIFLEx, + VAE_tile_size = VAE_tile_size ) else: @@ -833,7 +852,8 @@ def generate_video( seed=seed, offload_model=False, callback=callback, - enable_RIFLEx = enable_RIFLEx + enable_RIFLEx = enable_RIFLEx, + VAE_tile_size = VAE_tile_size ) except Exception as e: gen_in_progress = False @@ -987,7 +1007,6 @@ def create_demo(): gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM") gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM") gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM") - gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peaks (up to 12GB for 420P and 22 GB for 720P)") gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear") gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.") @@ -1076,8 +1095,9 @@ def create_demo(): vae_config_choice = gr.Dropdown( choices=[ ("Auto", 0), - ("Disabled (faster but may require up to 24 GB of VRAM)", 1), - ("Enabled (2x slower and up to 50% VRAM reduction)", 2), + ("Disabled (faster but may require up to 22 GB of VRAM)", 1), + ("256 x 256 : If at least 8 GB of VRAM", 2), + ("128 x 128 : If at least 6 GB of VRAM", 3), ], value= vae_config, label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding" diff --git a/wan/image2video.py b/wan/image2video.py index 2f9aeea..f295141 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -144,7 +144,8 @@ class WanI2V: seed=-1, offload_model=True, callback = None, - enable_RIFLEx = False + enable_RIFLEx = False, + VAE_tile_size= 0, ): r""" @@ -254,7 +255,7 @@ class WanI2V: ], dim=1).to(self.device) # enc = None - y = self.vae.encode([enc])[0] + y = self.vae.encode([enc], VAE_tile_size)[0] y = torch.concat([msk, y]) @contextmanager @@ -363,7 +364,7 @@ class WanI2V: torch.cuda.empty_cache() if self.rank == 0: - videos = self.vae.decode(x0) + videos = self.vae.decode(x0, VAE_tile_size) del noise, latent del sample_scheduler diff --git a/wan/modules/vae.py b/wan/modules/vae.py index ded9f92..fac20dd 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -35,11 +35,6 @@ class CausalConv3d(nn.Conv3d): x = F.pad(x, padding) x = super().forward(x) - mem_threshold = offload.shared_state.get("_vae_threshold",0) - vae_config = offload.shared_state.get("_vae",1) - - if vae_config == 0 and torch.cuda.memory_reserved() > mem_threshold or vae_config == 2: - torch.cuda.empty_cache() return x @@ -346,8 +341,6 @@ class Encoder3d(nn.Module): x = self.conv1(x) - # torch.cuda.empty_cache() - ## downsamples for layer in self.downsamples: if feat_cache is not None: @@ -355,7 +348,6 @@ class Encoder3d(nn.Module): else: x = layer(x) - # torch.cuda.empty_cache() ## middle for layer in self.middle: @@ -364,7 +356,6 @@ class Encoder3d(nn.Module): else: x = layer(x) - # torch.cuda.empty_cache() ## head for layer in self.head: @@ -385,7 +376,6 @@ class Encoder3d(nn.Module): else: x = layer(x) - # torch.cuda.empty_cache() return x @@ -540,7 +530,7 @@ class WanVAE_(nn.Module): x_recon = self.decode(z) return x_recon, mu, log_var - def encode(self, x, scale): + def encode(self, x, scale = None): self.clear_cache() ## cache t = x.shape[2] @@ -562,22 +552,25 @@ class WanVAE_(nn.Module): mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] + if scale != None: + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] self.clear_cache() return mu - def decode(self, z, scale): + + def decode(self, z, scale=None): self.clear_cache() # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] + if scale != None: + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): @@ -595,6 +588,104 @@ class WanVAE_(nn.Module): out = torch.cat([out, out_], 2) self.clear_cache() return out + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def spatial_tiled_decode(self, z, scale, tile_size): + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 8) + tile_overlap_factor = 0.25 + + # z: [b,c,t,h,w] + + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + + + overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) #8 0.75 + blend_extent = int(tile_sample_min_size * tile_overlap_factor) #256 0.25 + row_limit = tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size] + decoded = self.decode(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + return torch.cat(result_rows, dim=-2) + + + def spatial_tiled_encode(self, x, scale, tile_size) : + tile_sample_min_size = tile_size + tile_latent_min_size = int(tile_sample_min_size / 8) + tile_overlap_factor = 0.25 + + overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor)) + blend_extent = int(tile_latent_min_size * tile_overlap_factor) + row_limit = tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size] + tile = self.encode(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + mu = torch.cat(result_rows, dim=-2) + + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + + return mu + def reparameterize(self, mu, log_var): std = torch.exp(0.5 * log_var) @@ -673,18 +764,18 @@ class WanVAE: z_dim=z_dim, ).eval().requires_grad_(False).to(device) - def encode(self, videos): + def encode(self, videos, tile_size = 256): """ videos: A list of videos each with shape [C, T, H, W]. """ - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] + if tile_size > 0: + return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos ] + else: + return [ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos ] - def decode(self, zs): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + + def decode(self, zs, tile_size): + if tile_size > 0: + return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs ] + else: + return [ self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs ] diff --git a/wan/text2video.py b/wan/text2video.py index 2f1d9d0..c7efb92 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -129,7 +129,8 @@ class WanT2V: seed=-1, offload_model=True, callback = None, - enable_RIFLEx = None + enable_RIFLEx = None, + VAE_tile_size = 0 ): r""" Generates video frames from text prompt using diffusion process. @@ -286,7 +287,7 @@ class WanT2V: self.model.cpu() torch.cuda.empty_cache() if self.rank == 0: - videos = self.vae.decode(x0) + videos = self.vae.decode(x0, VAE_tile_size) del noise, latents