mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 06:15:17 +00:00

oops

This commit is contained in:
parent 66ddadf0cc
commit 6b17c9fb6a
@@ -113,6 +113,7 @@ class WanAny2V:
         self.vae = vae(
             vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
             device="cpu")
+        self.vae.device = self.device
 
         # config_filename= "configs/t2v_1.3B.json"
         # import json
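The added line pins the VAE's `device` attribute to the pipeline's compute device right after the VAE itself is instantiated on CPU. A minimal sketch of that pattern, using hypothetical TinyVAE/TinyPipeline stand-ins (only the argument and attribute names visible in the hunk are taken from the diff):

# Hypothetical stand-ins for the real classes; the real loading code differs.
import os
import torch

class TinyVAE:
    def __init__(self, vae_pth, dtype, device):
        self.vae_pth = vae_pth
        self.dtype = dtype
        self.device = torch.device(device)  # weights start out on CPU

class TinyPipeline:
    def __init__(self, checkpoint_dir, vae_checkpoint, VAE_dtype, device="cuda"):
        self.device = torch.device(device)
        # Build the VAE on CPU first so loading does not touch VRAM ...
        self.vae = TinyVAE(
            vae_pth=os.path.join(checkpoint_dir, vae_checkpoint),
            dtype=VAE_dtype,
            device="cpu")
        # ... then point its device attribute at the pipeline's compute device,
        # so encode()/decode() know where working tensors should be moved.
        self.vae.device = self.device

With this arrangement the heavy weights can stay parked on CPU until they are needed, while `self.vae.device` records where later calls should place intermediate tensors.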
@@ -467,7 +468,6 @@ class WanAny2V:
         color_reference_frame = None
         if self._interrupt:
             return None
 
         # Text Encoder
         if n_prompt == "":
             n_prompt = self.sample_neg_prompt
@@ -832,8 +832,7 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        original_device = videos[0].device
-        scale = [u.to(device = original_device) for u in self.scale]
+        scale = [u.to(device = self.device) for u in self.scale]
         if tile_size > 0:
             return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
@@ -841,8 +840,7 @@ class WanVAE:
 
 
     def decode(self, zs, tile_size, any_end_frame = False):
-        original_device = zs[0].device
-        scale = [u.to(device = original_device) for u in self.scale]
+        scale = [u.to(device = self.device) for u in self.scale]
         if tile_size > 0:
             return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
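The two WanVAE hunks drop the per-call `original_device` lookup and move the latent `scale` tensors to the module's own `self.device` instead, so the normalization constants follow the configured compute device rather than whatever device the inputs arrived on. A self-contained sketch of the two behaviours, with illustrative names rather than the repository's classes:

# Toy container holding per-channel latent scale tensors, as WanVAE does.
import torch

class ScaleHolder:
    def __init__(self, scale, device):
        self.scale = scale                     # e.g. [mean, 1/std]
        self.device = torch.device(device)     # configured compute device

    def scale_like_inputs(self, videos):
        # Old behaviour: mirror the device of the first input tensor.
        original_device = videos[0].device
        return [u.to(device=original_device) for u in self.scale]

    def scale_like_module(self):
        # New behaviour: always move to the module's configured device.
        return [u.to(device=self.device) for u in self.scale]

holder = ScaleHolder([torch.zeros(16), torch.ones(16)], device="cpu")
videos = [torch.randn(3, 8, 64, 64)]
assert all(s.device == videos[0].device for s in holder.scale_like_inputs(videos))
assert all(s.device == holder.device for s in holder.scale_like_module())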
@@ -750,7 +750,7 @@ def count_conv3d(model):
 
 
 class WanVAE_(nn.Module):
-
+    _offload_hooks = ['encode', 'decode']
     def __init__(
         self,
        dim=160,
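`_offload_hooks = ['encode', 'decode']` is a class-level list naming the methods that the offload machinery is expected to wrap. The diff does not show the consumer of this attribute, so the wrapper below is only a hypothetical illustration of how such a hook list could be used:

# Hypothetical consumer of _offload_hooks, not code from the repository:
# wrap each listed method so the module's weights are loaded onto the compute
# device just before the call and parked back on CPU afterwards.
import functools
import torch.nn as nn

def install_offload_hooks(module: nn.Module, compute_device="cuda", park_device="cpu"):
    for name in getattr(module, "_offload_hooks", []):
        original = getattr(module, name)

        @functools.wraps(original)
        def hooked(*args, _original=original, **kwargs):
            module.to(compute_device)        # bring weights in for this call
            try:
                return _original(*args, **kwargs)
            finally:
                module.to(park_device)       # free VRAM once the call returns

        setattr(module, name, hooked)
    return module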
@@ -1173,8 +1173,7 @@ class Wan2_2_VAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        original_device = videos[0].device
-        scale = [u.to(device = original_device) for u in self.scale]
+        scale = [u.to(device = self.device) for u in self.scale]
 
         if tile_size > 0 and False:
             return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
@@ -1183,8 +1182,7 @@ class Wan2_2_VAE:
 
 
     def decode(self, zs, tile_size, any_end_frame = False):
-        original_device = zs[0].device
-        scale = [u.to(device = original_device) for u in self.scale]
+        scale = [u.to(device = self.device) for u in self.scale]
 
         if tile_size > 0 and False:
             return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
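In the surrounding context of both VAE classes, encode and decode dispatch on `tile_size` between a spatially tiled path and a full-frame path (the tiled branch is hard-disabled with `and False` in the Wan2_2_VAE hunks). A toy sketch of that dispatch shape, with a hypothetical TinyCodec in place of the real model:

# Toy stand-in showing the tile_size dispatch shape, not the real VAE.
import torch

class TinyCodec:
    def full_decode(self, z):
        return z.float()

    def spatial_tiled_decode(self, z, tile_size):
        # Process the latent in tile_size-wide chunks along the last axis (toy version).
        chunks = [self.full_decode(c) for c in z.split(tile_size, dim=-1)]
        return torch.cat(chunks, dim=-1)

    def decode(self, zs, tile_size):
        if tile_size > 0:
            return [self.spatial_tiled_decode(z, tile_size).clamp_(-1, 1) for z in zs]
        return [self.full_decode(z).clamp_(-1, 1) for z in zs]

codec = TinyCodec()
frames = codec.decode([torch.randn(16, 4, 32, 32)], tile_size=8)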