From 8dc2e1c3d981b73db084e7c045afb338727875f9 Mon Sep 17 00:00:00 2001
From: deepbeepmeep
Date: Wed, 6 Aug 2025 22:07:48 +0200
Subject: [PATCH] fixed Wan 2.2 VAE crash

---
 models/wan/modules/vae2_2.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py
index 1592e13..3e67f54 100644
--- a/models/wan/modules/vae2_2.py
+++ b/models/wan/modules/vae2_2.py
@@ -37,11 +37,29 @@ class CausalConv3d(nn.Conv3d):
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
             padding[4] -= cache_x.shape[2]
+            cache_x = None
         x = F.pad(x, padding)
-
-        return super().forward(x)
-
-
+        try:
+            out = super().forward(x)
+            return out
+        except RuntimeError as e:
+            if "miopenStatus" in str(e):
+                print("⚠️ MIOpen fallback: MIOpen fails on large convolution workloads on AMD GPUs, so this "
+                      "decode will run on the CPU instead (which is very slow). Consider using tiled VAE decoding.")
+                x_cpu = x.float().cpu()
+                weight_cpu = self.weight.float().cpu()
+                bias_cpu = self.bias.float().cpu() if self.bias is not None else None
+                print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
+                out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
+                               self.stride, (0, 0, 0),  # x is already padded by F.pad; avoid double padding
+                               self.dilation, self.groups)
+                out = out.to(x.device)
+                if x.dtype in (torch.float16, torch.bfloat16):
+                    out = out.half()
+                if x.dtype != out.dtype:
+                    out = out.to(x.dtype)
+                return out
+            raise
 class RMS_norm(nn.Module):
 
     def __init__(self, dim, channel_first=True, images=True, bias=False):
@@ -1124,6 +1142,8 @@ class Wan2_2_VAE:
                 temperal_downsample=temperal_downsample,
             ).eval().requires_grad_(False).to(device))
 
+        self.model._model_dtype = dtype
+
     @staticmethod
     def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
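
Note: the hunk above wraps nn.Conv3d.forward in a try/except and, when MIOpen
raises, re-runs the convolution on the CPU in float32 before moving and
casting the result back. A minimal, self-contained sketch of the same
pattern, assuming only PyTorch (conv and x are illustrative stand-ins, not
names from vae2_2.py):

import torch
import torch.nn.functional as F

def conv3d_with_cpu_fallback(conv: torch.nn.Conv3d, x: torch.Tensor) -> torch.Tensor:
    # Try the convolution on the tensor's current device first.
    try:
        return conv(x)
    except RuntimeError as e:
        # Only MIOpen failures (AMD/ROCm backend) are handled; anything else re-raises.
        if "miopenStatus" not in str(e):
            raise
        # Recompute in float32 on the CPU, then move and cast back in one step.
        out = F.conv3d(x.float().cpu(),
                       conv.weight.float().cpu(),
                       conv.bias.float().cpu() if conv.bias is not None else None,
                       conv.stride, conv.padding, conv.dilation, conv.groups)
        return out.to(device=x.device, dtype=x.dtype)

Two differences from the patch: the sketch passes conv.padding through because
it does not assume the input was pre-padded by F.pad, and the hunk's pair of
dtype checks reduces to the single out.to(..., dtype=x.dtype) cast used here.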