mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
If an error occurs because AMD is asked to VAE Decode without tiling, warn and use CPU decoding.
This commit is contained in:
parent
18bf248fb8
commit
4ca0666aa5
@ -35,32 +35,27 @@ class CausalConv3d(nn.Conv3d):
|
|||||||
x = F.pad(x, padding)
|
x = F.pad(x, padding)
|
||||||
try:
|
try:
|
||||||
out = super().forward(x)
|
out = super().forward(x)
|
||||||
print("(ran fine)")
|
|
||||||
return out
|
return out
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "miopenStatus" in str(e):
|
if "miopenStatus" in str(e):
|
||||||
print("⚠️ MIOpen fallback: running Conv3d on CPU")
|
print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be "
|
||||||
|
"used for this decoding (which is very slow). Consider using tiled VAE Decoding.")
|
||||||
x_cpu = x.float().cpu()
|
x_cpu = x.float().cpu()
|
||||||
weight_cpu = self.weight.float().cpu()
|
weight_cpu = self.weight.float().cpu()
|
||||||
bias_cpu = self.bias.float().cpu() if self.bias is not None else None
|
bias_cpu = self.bias.float().cpu() if self.bias is not None else None
|
||||||
|
|
||||||
print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
|
print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
|
||||||
out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
|
out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
|
||||||
self.stride, (0, 0, 0), # <-- FIX: no padding here
|
self.stride, (0, 0, 0), # avoid double padding here
|
||||||
self.dilation, self.groups)
|
self.dilation, self.groups)
|
||||||
|
|
||||||
out = out.to(x.device)
|
out = out.to(x.device)
|
||||||
if x.dtype in (torch.float16, torch.bfloat16):
|
if x.dtype in (torch.float16, torch.bfloat16):
|
||||||
out = out.half()
|
out = out.half()
|
||||||
if x.dtype != out.dtype:
|
if x.dtype != out.dtype:
|
||||||
out = out.to(x.dtype)
|
out = out.to(x.dtype)
|
||||||
print("... returned (from CPU fallback)")
|
|
||||||
return out
|
return out
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RMS_norm(nn.Module):
|
class RMS_norm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user