mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
oops
This commit is contained in:
parent
66ddadf0cc
commit
6b17c9fb6a
@ -113,6 +113,7 @@ class WanAny2V:
|
||||
self.vae = vae(
|
||||
vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
|
||||
device="cpu")
|
||||
self.vae.device = self.device
|
||||
|
||||
# config_filename= "configs/t2v_1.3B.json"
|
||||
# import json
|
||||
@ -467,7 +468,6 @@ class WanAny2V:
|
||||
color_reference_frame = None
|
||||
if self._interrupt:
|
||||
return None
|
||||
|
||||
# Text Encoder
|
||||
if n_prompt == "":
|
||||
n_prompt = self.sample_neg_prompt
|
||||
|
||||
@ -832,8 +832,7 @@ class WanVAE:
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
"""
|
||||
original_device = videos[0].device
|
||||
scale = [u.to(device = original_device) for u in self.scale]
|
||||
scale = [u.to(device = self.device) for u in self.scale]
|
||||
if tile_size > 0:
|
||||
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
||||
else:
|
||||
@ -841,8 +840,7 @@ class WanVAE:
|
||||
|
||||
|
||||
def decode(self, zs, tile_size, any_end_frame = False):
|
||||
original_device = zs[0].device
|
||||
scale = [u.to(device = original_device) for u in self.scale]
|
||||
scale = [u.to(device = self.device) for u in self.scale]
|
||||
if tile_size > 0:
|
||||
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
||||
else:
|
||||
|
||||
@ -750,7 +750,7 @@ def count_conv3d(model):
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
_offload_hooks = ['encode', 'decode']
|
||||
def __init__(
|
||||
self,
|
||||
dim=160,
|
||||
@ -1173,8 +1173,7 @@ class Wan2_2_VAE:
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
"""
|
||||
original_device = videos[0].device
|
||||
scale = [u.to(device = original_device) for u in self.scale]
|
||||
scale = [u.to(device = self.device) for u in self.scale]
|
||||
|
||||
if tile_size > 0 and False:
|
||||
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
||||
@ -1183,8 +1182,7 @@ class Wan2_2_VAE:
|
||||
|
||||
|
||||
def decode(self, zs, tile_size, any_end_frame = False):
|
||||
original_device = zs[0].device
|
||||
scale = [u.to(device = original_device) for u in self.scale]
|
||||
scale = [u.to(device = self.device) for u in self.scale]
|
||||
|
||||
if tile_size > 0 and False:
|
||||
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user