If a MIOpen error occurs because an AMD GPU is asked to run VAE Decode without tiling, warn the user and fall back to CPU decoding.

This commit is contained in:
Christopher Anderson 2025-06-15 10:12:03 +10:00
parent 18bf248fb8
commit 4ca0666aa5

View File

@ -35,32 +35,27 @@ class CausalConv3d(nn.Conv3d):
x = F.pad(x, padding) x = F.pad(x, padding)
try: try:
out = super().forward(x) out = super().forward(x)
print("(ran fine)")
return out return out
except RuntimeError as e: except RuntimeError as e:
if "miopenStatus" in str(e): if "miopenStatus" in str(e):
print("⚠️ MIOpen fallback: running Conv3d on CPU") print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be "
"used for this decoding (which is very slow). Consider using tiled VAE Decoding.")
x_cpu = x.float().cpu() x_cpu = x.float().cpu()
weight_cpu = self.weight.float().cpu() weight_cpu = self.weight.float().cpu()
bias_cpu = self.bias.float().cpu() if self.bias is not None else None bias_cpu = self.bias.float().cpu() if self.bias is not None else None
print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}") print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
out = F.conv3d(x_cpu, weight_cpu, bias_cpu, out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
self.stride, (0, 0, 0), # <-- FIX: no padding here self.stride, (0, 0, 0), # avoid double padding here
self.dilation, self.groups) self.dilation, self.groups)
out = out.to(x.device) out = out.to(x.device)
if x.dtype in (torch.float16, torch.bfloat16): if x.dtype in (torch.float16, torch.bfloat16):
out = out.half() out = out.half()
if x.dtype != out.dtype: if x.dtype != out.dtype:
out = out.to(x.dtype) out = out.to(x.dtype)
print("... returned (from CPU fallback)")
return out return out
raise raise
class RMS_norm(nn.Module): class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False): def __init__(self, dim, channel_first=True, images=True, bias=False):