Implemented VAE tiling

DeepBeepMeep 2025-03-04 02:39:44 +01:00
parent 79335caa99
commit 697cc2cce5
5 changed files with 160 additions and 47 deletions


@ -20,6 +20,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling for VAE encoding and decoding. No more VRAM peaks at the beginning and at the end of the generation
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan including the Image to Video model
@ -126,7 +127,6 @@ To run the application while loading entirely the diffusion model in VRAM (sligh
```bash
python gradio_server.py --profile 3
```
Please note that the diffusion model of Wan2.1GP is extremely VRAM-optimized, which greatly benefits low-VRAM systems since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of an image-to-video process) and the VAE decoder (at the end of any video process) are still VRAM hungry after optimization and will temporarily require 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore, if you have less than these amounts, you may experience slowdowns at the beginning and at the end of the generation process due to PyTorch VRAM offloading.
### Loras support


@ -433,7 +433,7 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100, "*" : 3000 }
kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 }
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
@ -693,11 +693,29 @@ def generate_video(
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
offload.shared_state["_vae"] = vae_config
offload.shared_state["_vae_threshold"] = 0.9* torch.cuda.get_device_properties(0).total_memory
offload.shared_state["_attention"] = attn
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
if vae_config == 0:
if device_mem_capacity >= 24000:
use_vae_config = 1
elif device_mem_capacity >= 8000:
use_vae_config = 2
else:
use_vae_config = 3
else:
use_vae_config = vae_config
if use_vae_config == 1:
VAE_tile_size = 0
elif use_vae_config == 2:
VAE_tile_size = 256
else:
VAE_tile_size = 128
global gen_in_progress
gen_in_progress = True
temp_filename = None
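For readers skimming the diff, the auto mode simply maps total device VRAM to a tile size. A minimal sketch of the same selection as a standalone helper (the function name `pick_vae_tile_size` is invented for illustration; the thresholds and return values mirror the hunk above):

```python
import torch

def pick_vae_tile_size(vae_config: int) -> int:
    """Illustrative sketch of the selection logic added above.

    vae_config: 0 = auto, 1 = tiling disabled, 2 = 256 x 256 tiles, 3 = 128 x 128 tiles.
    Returns 0 (no tiling) or the spatial tile size in pixels.
    """
    if vae_config == 0:
        # Total device memory in MB, as computed in the hunk above.
        device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
        if device_mem_capacity >= 24000:
            vae_config = 1   # roughly 24 GB or more: decode / encode in one pass
        elif device_mem_capacity >= 8000:
            vae_config = 2   # roughly 8 GB or more: 256 x 256 tiles
        else:
            vae_config = 3   # below that: 128 x 128 tiles
    return {1: 0, 2: 256}.get(vae_config, 128)
```

The selected value is then passed as `VAE_tile_size` into the `generate` calls shown in the hunks below.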
@ -818,7 +836,8 @@ def generate_video(
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size
)
else:
@ -833,7 +852,8 @@ def generate_video(
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size
)
except Exception as e:
gen_in_progress = False
@ -987,7 +1007,6 @@ def create_demo():
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peaks (up to 12GB for 420P and 22 GB for 720P)")
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
@ -1076,8 +1095,9 @@ def create_demo():
vae_config_choice = gr.Dropdown(
choices=[
("Auto", 0),
("Disabled (faster but may require up to 24 GB of VRAM)", 1),
("Enabled (2x slower and up to 50% VRAM reduction)", 2),
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
("256 x 256 : If at least 8 GB of VRAM", 2),
("128 x 128 : If at least 6 GB of VRAM", 3),
],
value= vae_config,
label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"


@ -144,7 +144,8 @@ class WanI2V:
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False
enable_RIFLEx = False,
VAE_tile_size= 0,
):
r"""
@ -254,7 +255,7 @@ class WanI2V:
], dim=1).to(self.device)
# enc = None
y = self.vae.encode([enc])[0]
y = self.vae.encode([enc], VAE_tile_size)[0]
y = torch.concat([msk, y])
@contextmanager
@ -363,7 +364,7 @@ class WanI2V:
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
videos = self.vae.decode(x0, VAE_tile_size)
del noise, latent
del sample_scheduler


@ -35,11 +35,6 @@ class CausalConv3d(nn.Conv3d):
x = F.pad(x, padding)
x = super().forward(x)
mem_threshold = offload.shared_state.get("_vae_threshold",0)
vae_config = offload.shared_state.get("_vae",1)
if vae_config == 0 and torch.cuda.memory_reserved() > mem_threshold or vae_config == 2:
torch.cuda.empty_cache()
return x
@ -346,8 +341,6 @@ class Encoder3d(nn.Module):
x = self.conv1(x)
# torch.cuda.empty_cache()
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
@ -355,7 +348,6 @@ class Encoder3d(nn.Module):
else:
x = layer(x)
# torch.cuda.empty_cache()
## middle
for layer in self.middle:
@ -364,7 +356,6 @@ class Encoder3d(nn.Module):
else:
x = layer(x)
# torch.cuda.empty_cache()
## head
for layer in self.head:
@ -385,7 +376,6 @@ class Encoder3d(nn.Module):
else:
x = layer(x)
# torch.cuda.empty_cache()
return x
@ -540,7 +530,7 @@ class WanVAE_(nn.Module):
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
def encode(self, x, scale = None):
self.clear_cache()
## cache
t = x.shape[2]
@ -562,22 +552,25 @@ class WanVAE_(nn.Module):
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
if scale is not None:
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
def decode(self, z, scale=None):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
if scale is not None:
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
@ -595,6 +588,104 @@ class WanVAE_(nn.Module):
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
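`blend_v` and `blend_h` remove seams by linearly cross-fading the overlapping rows (or columns) of adjacent tiles: position `y` in the overlap takes weight `y / blend_extent` from the new tile and `1 - y / blend_extent` from its neighbour. A tiny self-contained demo of the horizontal blend (the helper body is copied free-standing here purely for illustration):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same logic as the method above, without the class wrapper.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = (a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent)
                            + b[:, :, :, :, x] * (x / blend_extent))
    return b

# Left tile of ones, right tile of zeros, blended over 4 columns:
a = torch.ones(1, 1, 1, 1, 8)
b = torch.zeros(1, 1, 1, 1, 8)
print(blend_h(a, b, 4)[0, 0, 0, 0, :4])  # tensor([1.0000, 0.7500, 0.5000, 0.2500]): a smooth ramp, no hard seam
```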
def spatial_tiled_decode(self, z, scale, tile_size):
tile_sample_min_size = tile_size
tile_latent_min_size = int(tile_sample_min_size / 8)
tile_overlap_factor = 0.25
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) # stride between tile origins, in latents (24 for a 256-pixel tile)
blend_extent = int(tile_sample_min_size * tile_overlap_factor) # cross-fade width, in output pixels (64 for a 256-pixel tile)
row_limit = tile_sample_min_size - blend_extent
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[-2], overlap_size):
row = []
for j in range(0, z.shape[-1], overlap_size):
tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
decoded = self.decode(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
return torch.cat(result_rows, dim=-2)
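With the default 256-pixel tiles and a 0.25 overlap factor, the decode-side constants work out as below. The point of the arithmetic is that adjacent latent tiles are strided so their decoded outputs overlap by exactly `blend_extent` pixels (a quick worked check, not code from the commit):

```python
tile_size = 256                    # decoded pixels per tile edge
tile_latent = tile_size // 8       # 32 latents per tile edge (8x spatial VAE factor)
overlap_factor = 0.25

overlap_size = int(tile_latent * (1 - overlap_factor))  # 24: stride between tiles, in latents
blend_extent = int(tile_size * overlap_factor)          # 64: cross-fade width, in output pixels
row_limit = tile_size - blend_extent                    # 192: pixels kept from each decoded tile

# 24 latents of stride decode to 192 pixels, so consecutive 256-pixel tiles
# overlap by exactly the 64 pixels that blend_v / blend_h fade across.
assert overlap_size * 8 == row_limit
```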
def spatial_tiled_encode(self, x, scale, tile_size) :
tile_sample_min_size = tile_size
tile_latent_min_size = int(tile_sample_min_size / 8)
tile_overlap_factor = 0.25
overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
blend_extent = int(tile_latent_min_size * tile_overlap_factor)
row_limit = tile_latent_min_size - blend_extent
# Split video into tiles and encode them separately.
rows = []
for i in range(0, x.shape[-2], overlap_size):
row = []
for j in range(0, x.shape[-1], overlap_size):
tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
tile = self.encode(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
mu = torch.cat(result_rows, dim=-2)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
return mu
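The encode path mirrors this with the two spaces swapped: the stride between tiles is measured in input pixels, while the cross-fade and crop happen on the latents. The same sanity arithmetic for the default settings (again purely illustrative):

```python
tile_size = 256                    # input pixels per tile edge
tile_latent = tile_size // 8       # 32 latents per tile edge
overlap_factor = 0.25

overlap_size = int(tile_size * (1 - overlap_factor))    # 192: stride between tiles, in pixels
blend_extent = int(tile_latent * overlap_factor)        # 8: cross-fade width, in latents
row_limit = tile_latent - blend_extent                  # 24: latents kept from each encoded tile

# 192 input pixels encode to 24 latents, so the stitched latent grid lines up
# with what a single-pass encode of the full frame would produce.
assert overlap_size // 8 == row_limit
```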
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
@ -673,18 +764,18 @@ class WanVAE:
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos):
def encode(self, videos, tile_size = 256):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
if tile_size > 0:
return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos ]
else:
return [ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos ]
def decode(self, zs):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
def decode(self, zs, tile_size):
if tile_size > 0:
return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs ]
else:
return [ self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs ]
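Taken together, callers now choose per run whether to tile, which is exactly what the `generate_video` changes above do. A condensed usage sketch (hedged: `vae` stands for an already-loaded `WanVAE` instance and `video` for a `[C, T, H, W]` tensor in `[-1, 1]` on the GPU; neither name is defined by this commit):

```python
import torch

VAE_tile_size = 256   # 0 disables tiling; 256 or 128 trade a little speed for much lower VRAM

with torch.no_grad():
    latents = vae.encode([video], VAE_tile_size)   # list of videos in, list of latents out
    frames = vae.decode(latents, VAE_tile_size)    # list of [C, T, H, W] tensors clamped to [-1, 1]
```

With `VAE_tile_size = 0`, both calls fall back to the original single-pass `encode` / `decode`.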


@ -129,7 +129,8 @@ class WanT2V:
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = None
enable_RIFLEx = None,
VAE_tile_size = 0
):
r"""
Generates video frames from text prompt using diffusion process.
@ -286,7 +287,7 @@ class WanT2V:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
videos = self.vae.decode(x0, VAE_tile_size)
del noise, latents