This commit is contained in:
deepbeepmeep 2025-08-08 01:40:48 +02:00
parent 66ddadf0cc
commit 6b17c9fb6a
3 changed files with 6 additions and 10 deletions

View File

@ -113,6 +113,7 @@ class WanAny2V:
self.vae = vae(
vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
device="cpu")
self.vae.device = self.device
# config_filename= "configs/t2v_1.3B.json"
# import json
@ -467,7 +468,6 @@ class WanAny2V:
color_reference_frame = None
if self._interrupt:
return None
# Text Encoder
if n_prompt == "":
n_prompt = self.sample_neg_prompt

View File

@ -832,8 +832,7 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
original_device = videos[0].device
scale = [u.to(device = original_device) for u in self.scale]
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0:
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
else:
@ -841,8 +840,7 @@ class WanVAE:
def decode(self, zs, tile_size, any_end_frame = False):
original_device = zs[0].device
scale = [u.to(device = original_device) for u in self.scale]
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0:
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
else:

View File

@ -750,7 +750,7 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
_offload_hooks = ['encode', 'decode']
def __init__(
self,
dim=160,
@ -1173,8 +1173,7 @@ class Wan2_2_VAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
original_device = videos[0].device
scale = [u.to(device = original_device) for u in self.scale]
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0 and False:
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
@ -1183,8 +1182,7 @@ class Wan2_2_VAE:
def decode(self, zs, tile_size, any_end_frame = False):
original_device = zs[0].device
scale = [u.to(device = original_device) for u in self.scale]
scale = [u.to(device = self.device) for u in self.scale]
if tile_size > 0 and False:
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]