mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Implemented VAE tiling
This commit is contained in:
parent
79335caa99
commit
697cc2cce5
@ -20,6 +20,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
||||
|
||||
## 🔥 Latest News!!
|
||||
|
||||
* Mar 03, 2025: 👋 Wan2.1GP v1.2: Implementented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
|
||||
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
|
||||
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
|
||||
- Support for all Wan including the Image to Video model
|
||||
@ -126,7 +127,6 @@ To run the application while loading entirely the diffusion model in VRAM (sligh
|
||||
```bash
|
||||
python gradio_server.py --profile 3
|
||||
```
|
||||
Please note that diffusion model of Wan2.1GP is extremely VRAM optimized and this will greatly benefit low VRAM systems since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of a image 2 video process) and the VAE decoder (at the end of any video process) is still VRAM hungry after optimization and it will require temporarly 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore if you have less than these numbers, you may experience slow downs at the beginning and at the end of the generation process due to pytorch VRAM offloading.
|
||||
|
||||
|
||||
### Loras support
|
||||
|
||||
@ -433,7 +433,7 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
|
||||
|
||||
kwargs = { "extraModelsToQuantize": None}
|
||||
if profile == 2 or profile == 4:
|
||||
kwargs["budgets"] = { "transformer" : 100, "*" : 3000 }
|
||||
kwargs["budgets"] = { "transformer" : 100, "text_encoder" : 100, "*" : 1000 }
|
||||
|
||||
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
|
||||
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
|
||||
@ -693,11 +693,29 @@ def generate_video(
|
||||
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
|
||||
raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
|
||||
|
||||
offload.shared_state["_vae"] = vae_config
|
||||
offload.shared_state["_vae_threshold"] = 0.9* torch.cuda.get_device_properties(0).total_memory
|
||||
|
||||
offload.shared_state["_attention"] = attn
|
||||
|
||||
# VAE Tiling
|
||||
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
|
||||
if vae_config == 0:
|
||||
if device_mem_capacity >= 24000:
|
||||
use_vae_config = 1
|
||||
elif device_mem_capacity >= 8000:
|
||||
use_vae_config = 2
|
||||
else:
|
||||
use_vae_config = 3
|
||||
else:
|
||||
use_vae_config = vae_config
|
||||
|
||||
if use_vae_config == 1:
|
||||
VAE_tile_size = 0
|
||||
elif use_vae_config == 2:
|
||||
VAE_tile_size = 256
|
||||
else:
|
||||
VAE_tile_size = 128
|
||||
|
||||
|
||||
global gen_in_progress
|
||||
gen_in_progress = True
|
||||
temp_filename = None
|
||||
@ -818,7 +836,8 @@ def generate_video(
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size
|
||||
)
|
||||
|
||||
else:
|
||||
@ -833,7 +852,8 @@ def generate_video(
|
||||
seed=seed,
|
||||
offload_model=False,
|
||||
callback=callback,
|
||||
enable_RIFLEx = enable_RIFLEx
|
||||
enable_RIFLEx = enable_RIFLEx,
|
||||
VAE_tile_size = VAE_tile_size
|
||||
)
|
||||
except Exception as e:
|
||||
gen_in_progress = False
|
||||
@ -987,7 +1007,6 @@ def create_demo():
|
||||
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
|
||||
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
|
||||
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
|
||||
gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peaks (up to 12GB for 420P and 22 GB for 720P)")
|
||||
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
|
||||
gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
||||
|
||||
@ -1076,8 +1095,9 @@ def create_demo():
|
||||
vae_config_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("Auto", 0),
|
||||
("Disabled (faster but may require up to 24 GB of VRAM)", 1),
|
||||
("Enabled (2x slower and up to 50% VRAM reduction)", 2),
|
||||
("Disabled (faster but may require up to 22 GB of VRAM)", 1),
|
||||
("256 x 256 : If at least 8 GB of VRAM", 2),
|
||||
("128 x 128 : If at least 6 GB of VRAM", 3),
|
||||
],
|
||||
value= vae_config,
|
||||
label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
|
||||
|
||||
@ -144,7 +144,8 @@ class WanI2V:
|
||||
seed=-1,
|
||||
offload_model=True,
|
||||
callback = None,
|
||||
enable_RIFLEx = False
|
||||
enable_RIFLEx = False,
|
||||
VAE_tile_size= 0,
|
||||
|
||||
):
|
||||
r"""
|
||||
@ -254,7 +255,7 @@ class WanI2V:
|
||||
], dim=1).to(self.device)
|
||||
# enc = None
|
||||
|
||||
y = self.vae.encode([enc])[0]
|
||||
y = self.vae.encode([enc], VAE_tile_size)[0]
|
||||
y = torch.concat([msk, y])
|
||||
|
||||
@contextmanager
|
||||
@ -363,7 +364,7 @@ class WanI2V:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.rank == 0:
|
||||
videos = self.vae.decode(x0)
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
|
||||
del noise, latent
|
||||
del sample_scheduler
|
||||
|
||||
@ -35,11 +35,6 @@ class CausalConv3d(nn.Conv3d):
|
||||
x = F.pad(x, padding)
|
||||
x = super().forward(x)
|
||||
|
||||
mem_threshold = offload.shared_state.get("_vae_threshold",0)
|
||||
vae_config = offload.shared_state.get("_vae",1)
|
||||
|
||||
if vae_config == 0 and torch.cuda.memory_reserved() > mem_threshold or vae_config == 2:
|
||||
torch.cuda.empty_cache()
|
||||
return x
|
||||
|
||||
|
||||
@ -346,8 +341,6 @@ class Encoder3d(nn.Module):
|
||||
x = self.conv1(x)
|
||||
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
@ -355,7 +348,6 @@ class Encoder3d(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
@ -364,7 +356,6 @@ class Encoder3d(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
@ -385,7 +376,6 @@ class Encoder3d(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
return x
|
||||
|
||||
@ -540,7 +530,7 @@ class WanVAE_(nn.Module):
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
def encode(self, x, scale = None):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
@ -562,22 +552,25 @@ class WanVAE_(nn.Module):
|
||||
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
if scale != None:
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z, scale):
|
||||
|
||||
def decode(self, z, scale=None):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
if scale != None:
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
@ -595,6 +588,104 @@ class WanVAE_(nn.Module):
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def spatial_tiled_decode(self, z, scale, tile_size):
|
||||
tile_sample_min_size = tile_size
|
||||
tile_latent_min_size = int(tile_sample_min_size / 8)
|
||||
tile_overlap_factor = 0.25
|
||||
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
|
||||
|
||||
overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) #8 0.75
|
||||
blend_extent = int(tile_sample_min_size * tile_overlap_factor) #256 0.25
|
||||
row_limit = tile_sample_min_size - blend_extent
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, z.shape[-2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, z.shape[-1], overlap_size):
|
||||
tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
|
||||
decoded = self.decode(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
return torch.cat(result_rows, dim=-2)
|
||||
|
||||
|
||||
def spatial_tiled_encode(self, x, scale, tile_size) :
|
||||
tile_sample_min_size = tile_size
|
||||
tile_latent_min_size = int(tile_sample_min_size / 8)
|
||||
tile_overlap_factor = 0.25
|
||||
|
||||
overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
|
||||
blend_extent = int(tile_latent_min_size * tile_overlap_factor)
|
||||
row_limit = tile_latent_min_size - blend_extent
|
||||
|
||||
# Split video into tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[-2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[-1], overlap_size):
|
||||
tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
|
||||
tile = self.encode(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
mu = torch.cat(result_rows, dim=-2)
|
||||
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
|
||||
return mu
|
||||
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
@ -673,18 +764,18 @@ class WanVAE:
|
||||
z_dim=z_dim,
|
||||
).eval().requires_grad_(False).to(device)
|
||||
|
||||
def encode(self, videos):
|
||||
def encode(self, videos, tile_size = 256):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
"""
|
||||
return [
|
||||
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
||||
for u in videos
|
||||
]
|
||||
if tile_size > 0:
|
||||
return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos ]
|
||||
else:
|
||||
return [ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos ]
|
||||
|
||||
def decode(self, zs):
|
||||
return [
|
||||
self.model.decode(u.unsqueeze(0),
|
||||
self.scale).float().clamp_(-1, 1).squeeze(0)
|
||||
for u in zs
|
||||
]
|
||||
|
||||
def decode(self, zs, tile_size):
|
||||
if tile_size > 0:
|
||||
return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs ]
|
||||
else:
|
||||
return [ self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs ]
|
||||
|
||||
@ -129,7 +129,8 @@ class WanT2V:
|
||||
seed=-1,
|
||||
offload_model=True,
|
||||
callback = None,
|
||||
enable_RIFLEx = None
|
||||
enable_RIFLEx = None,
|
||||
VAE_tile_size = 0
|
||||
):
|
||||
r"""
|
||||
Generates video frames from text prompt using diffusion process.
|
||||
@ -286,7 +287,7 @@ class WanT2V:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
if self.rank == 0:
|
||||
videos = self.vae.decode(x0)
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
|
||||
|
||||
del noise, latents
|
||||
|
||||
Loading…
Reference in New Issue
Block a user