mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
fixed Wan 2.2 VAE crash
This commit is contained in:
parent
0e13d6b2a1
commit
8dc2e1c3d9
@ -37,11 +37,29 @@ class CausalConv3d(nn.Conv3d):
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
cache_x = None
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
try:
|
||||
out = super().forward(x)
|
||||
return out
|
||||
except RuntimeError as e:
|
||||
if "miopenStatus" in str(e):
|
||||
print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be "
|
||||
"used for this decoding (which is very slow). Consider using tiled VAE Decoding.")
|
||||
x_cpu = x.float().cpu()
|
||||
weight_cpu = self.weight.float().cpu()
|
||||
bias_cpu = self.bias.float().cpu() if self.bias is not None else None
|
||||
print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
|
||||
out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
|
||||
self.stride, (0, 0, 0), # avoid double padding here
|
||||
self.dilation, self.groups)
|
||||
out = out.to(x.device)
|
||||
if x.dtype in (torch.float16, torch.bfloat16):
|
||||
out = out.half()
|
||||
if x.dtype != out.dtype:
|
||||
out = out.to(x.dtype)
|
||||
return out
|
||||
raise
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
@ -1124,6 +1142,8 @@ class Wan2_2_VAE:
|
||||
temperal_downsample=temperal_downsample,
|
||||
).eval().requires_grad_(False).to(device))
|
||||
|
||||
self.model._model_dtype = dtype
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user