mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
oops
This commit is contained in:
parent
66ddadf0cc
commit
6b17c9fb6a
@ -113,6 +113,7 @@ class WanAny2V:
|
|||||||
self.vae = vae(
|
self.vae = vae(
|
||||||
vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
|
vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
self.vae.device = self.device
|
||||||
|
|
||||||
# config_filename= "configs/t2v_1.3B.json"
|
# config_filename= "configs/t2v_1.3B.json"
|
||||||
# import json
|
# import json
|
||||||
@ -467,7 +468,6 @@ class WanAny2V:
|
|||||||
color_reference_frame = None
|
color_reference_frame = None
|
||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Text Encoder
|
# Text Encoder
|
||||||
if n_prompt == "":
|
if n_prompt == "":
|
||||||
n_prompt = self.sample_neg_prompt
|
n_prompt = self.sample_neg_prompt
|
||||||
|
|||||||
@ -832,8 +832,7 @@ class WanVAE:
|
|||||||
"""
|
"""
|
||||||
videos: A list of videos each with shape [C, T, H, W].
|
videos: A list of videos each with shape [C, T, H, W].
|
||||||
"""
|
"""
|
||||||
original_device = videos[0].device
|
scale = [u.to(device = self.device) for u in self.scale]
|
||||||
scale = [u.to(device = original_device) for u in self.scale]
|
|
||||||
if tile_size > 0:
|
if tile_size > 0:
|
||||||
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
||||||
else:
|
else:
|
||||||
@ -841,8 +840,7 @@ class WanVAE:
|
|||||||
|
|
||||||
|
|
||||||
def decode(self, zs, tile_size, any_end_frame = False):
|
def decode(self, zs, tile_size, any_end_frame = False):
|
||||||
original_device = zs[0].device
|
scale = [u.to(device = self.device) for u in self.scale]
|
||||||
scale = [u.to(device = original_device) for u in self.scale]
|
|
||||||
if tile_size > 0:
|
if tile_size > 0:
|
||||||
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -750,7 +750,7 @@ def count_conv3d(model):
|
|||||||
|
|
||||||
|
|
||||||
class WanVAE_(nn.Module):
|
class WanVAE_(nn.Module):
|
||||||
|
_offload_hooks = ['encode', 'decode']
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim=160,
|
dim=160,
|
||||||
@ -1173,8 +1173,7 @@ class Wan2_2_VAE:
|
|||||||
"""
|
"""
|
||||||
videos: A list of videos each with shape [C, T, H, W].
|
videos: A list of videos each with shape [C, T, H, W].
|
||||||
"""
|
"""
|
||||||
original_device = videos[0].device
|
scale = [u.to(device = self.device) for u in self.scale]
|
||||||
scale = [u.to(device = original_device) for u in self.scale]
|
|
||||||
|
|
||||||
if tile_size > 0 and False:
|
if tile_size > 0 and False:
|
||||||
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
|
||||||
@ -1183,8 +1182,7 @@ class Wan2_2_VAE:
|
|||||||
|
|
||||||
|
|
||||||
def decode(self, zs, tile_size, any_end_frame = False):
|
def decode(self, zs, tile_size, any_end_frame = False):
|
||||||
original_device = zs[0].device
|
scale = [u.to(device = self.device) for u in self.scale]
|
||||||
scale = [u.to(device = original_device) for u in self.scale]
|
|
||||||
|
|
||||||
if tile_size > 0 and False:
|
if tile_size > 0 and False:
|
||||||
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user