mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	fixed Wan 2.2 VAE crash
This commit is contained in:
		
							parent
							
								
									0e13d6b2a1
								
							
						
					
					
						commit
						8dc2e1c3d9
					
				@ -37,11 +37,29 @@ class CausalConv3d(nn.Conv3d):
 | 
			
		||||
            cache_x = cache_x.to(x.device)
 | 
			
		||||
            x = torch.cat([cache_x, x], dim=2)
 | 
			
		||||
            padding[4] -= cache_x.shape[2]
 | 
			
		||||
            cache_x = None
 | 
			
		||||
        x = F.pad(x, padding)
 | 
			
		||||
 | 
			
		||||
        return super().forward(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            out = super().forward(x)
 | 
			
		||||
            return out
 | 
			
		||||
        except RuntimeError as e:
 | 
			
		||||
            if "miopenStatus" in str(e):
 | 
			
		||||
                print("⚠️ MIOpen fallback: AMD gets upset when trying to work with large areas, and so CPU will be "
 | 
			
		||||
                      "used for this decoding (which is very slow). Consider using tiled VAE Decoding.")
 | 
			
		||||
                x_cpu = x.float().cpu()
 | 
			
		||||
                weight_cpu = self.weight.float().cpu()
 | 
			
		||||
                bias_cpu = self.bias.float().cpu() if self.bias is not None else None
 | 
			
		||||
                print(f"[Fallback] x shape: {x_cpu.shape}, weight shape: {weight_cpu.shape}")
 | 
			
		||||
                out = F.conv3d(x_cpu, weight_cpu, bias_cpu,
 | 
			
		||||
                               self.stride, (0, 0, 0),  # avoid double padding here
 | 
			
		||||
                               self.dilation, self.groups)
 | 
			
		||||
                out = out.to(x.device)
 | 
			
		||||
                if x.dtype in (torch.float16, torch.bfloat16):
 | 
			
		||||
                    out = out.half()
 | 
			
		||||
                if x.dtype != out.dtype:
 | 
			
		||||
                    out = out.to(x.dtype)
 | 
			
		||||
                return out
 | 
			
		||||
            raise
 | 
			
		||||
class RMS_norm(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dim, channel_first=True, images=True, bias=False):
 | 
			
		||||
@ -1124,6 +1142,8 @@ class Wan2_2_VAE:
 | 
			
		||||
                temperal_downsample=temperal_downsample,
 | 
			
		||||
            ).eval().requires_grad_(False).to(device))
 | 
			
		||||
        
 | 
			
		||||
        self.model._model_dtype = dtype
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user