diff --git a/models/wan/any2video.py b/models/wan/any2video.py index a91cb73..b409279 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -113,6 +113,7 @@ class WanAny2V: self.vae = vae( vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype, device="cpu") + self.vae.device = self.device # config_filename= "configs/t2v_1.3B.json" # import json @@ -467,7 +468,6 @@ class WanAny2V: color_reference_frame = None if self._interrupt: return None - # Text Encoder if n_prompt == "": n_prompt = self.sample_neg_prompt diff --git a/models/wan/modules/vae.py b/models/wan/modules/vae.py index 0d1ab7a..feb0d2e 100644 --- a/models/wan/modules/vae.py +++ b/models/wan/modules/vae.py @@ -832,8 +832,7 @@ class WanVAE: """ videos: A list of videos each with shape [C, T, H, W]. """ - original_device = videos[0].device - scale = [u.to(device = original_device) for u in self.scale] + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0: return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: @@ -841,8 +840,7 @@ class WanVAE: def decode(self, zs, tile_size, any_end_frame = False): - original_device = zs[0].device - scale = [u.to(device = original_device) for u in self.scale] + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0: return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: diff --git a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py index c4d1e46..220979b 100644 --- a/models/wan/modules/vae2_2.py +++ b/models/wan/modules/vae2_2.py @@ -750,7 +750,7 @@ def count_conv3d(model): class WanVAE_(nn.Module): - + _offload_hooks = ['encode', 'decode'] def __init__( self, dim=160, @@ -1173,8 +1173,7 @@ class Wan2_2_VAE: """ videos: A list of videos each with shape [C, T, H, W]. """ - original_device = videos[0].device - scale = [u.to(device = original_device) for u in self.scale] + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0 and False: return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] @@ -1183,8 +1182,7 @@ class Wan2_2_VAE: def decode(self, zs, tile_size, any_end_frame = False): - original_device = zs[0].device - scale = [u.to(device = original_device) for u in self.scale] + scale = [u.to(device = self.device) for u in self.scale] if tile_size > 0 and False: return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]