added qwen vae tiling

DeepBeepMeep 2025-08-08 00:47:01 +02:00
parent 3c185ae69a
commit 66ddadf0cc
6 changed files with 59 additions and 15 deletions


@@ -674,6 +674,32 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = False
+    @staticmethod
+    def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
+        # VAE Tiling
+        if vae_config == 0:
+            if device_mem_capacity >= 24000:
+                use_vae_config = 1
+            elif device_mem_capacity >= 8000:
+                use_vae_config = 2
+            else:
+                use_vae_config = 3
+        else:
+            use_vae_config = vae_config
+        use_tiling = False
+        tile_sample_min_width = 256
+        if use_vae_config == 1:
+            use_tiling = False
+        elif use_vae_config == 2:
+            use_tiling = True
+            tile_sample_min_width = 256
+        return (use_tiling, tile_sample_min_width)
     # fmt: off
     @register_to_config
     def __init__(
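Note: for context, a minimal sketch of how the new helper might be called. The VRAM query, the reading of the 24000/8000 thresholds as megabytes, and the call site are assumptions of this sketch, not part of the commit.

import torch

# AutoencoderKLQwenImage is the class patched above; its import path is not shown in this commit.
# vae_config == 0 means "auto-select the tiling profile from reported VRAM".
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)  # assumed MB
use_tiling, tile_min_width = AutoencoderKLQwenImage.get_VAE_tile_size(
    vae_config=0,
    device_mem_capacity=device_mem_capacity,
    mixed_precision=False,   # accepted but unused in the body shown above
)
VAE_tile_size = (use_tiling, tile_min_width)  # later handed to the qwen pipeline's generate()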


@@ -83,6 +83,7 @@ class model_factory():
         loras_slists = None,
         batch_size = 1,
         video_prompt_type = "",
+        VAE_tile_size = None,
         **bbargs
     ):
         # Generate with different aspect ratios
@@ -94,6 +95,11 @@ class model_factory():
             "3:4": (1140, 1472)
         }
+        if VAE_tile_size is not None:
+            self.vae.use_tiling = VAE_tile_size[0]
+            self.vae.tile_latent_min_height = VAE_tile_size[1]
+            self.vae.tile_latent_min_width = VAE_tile_size[1]
         # width, height = aspect_ratios["16:9"]
         image = self.pipeline(
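Note: the tuple convention here is (use_tiling, tile size), with the single size applied to both the height and width thresholds; the value produced as tile_sample_min_width by the helper is assigned to the tile_latent_min_* attributes. A stand-in sketch of the same wiring, since the real pipeline and VAE classes live outside this hunk:

class _StubVAE:
    # Placeholder exposing only the three attributes the hunk above touches.
    use_tiling = False
    tile_latent_min_height = 0
    tile_latent_min_width = 0

def apply_vae_tile_size(vae, VAE_tile_size):
    # None means "leave the VAE's defaults alone", matching the guard above.
    if VAE_tile_size is not None:
        vae.use_tiling = VAE_tile_size[0]
        vae.tile_latent_min_height = VAE_tile_size[1]
        vae.tile_latent_min_width = VAE_tile_size[1]

vae = _StubVAE()
apply_vae_tile_size(vae, (True, 256))
assert vae.use_tiling and vae.tile_latent_min_height == vae.tile_latent_min_width == 256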


@@ -112,7 +112,7 @@ class WanAny2V:
         self.vae = vae(
             vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
-            device=self.device)
+            device="cpu")
         # config_filename= "configs/t2v_1.3B.json"
         # import json


@@ -832,16 +832,18 @@ class WanVAE:
         """
        videos: A list of videos each with shape [C, T, H, W].
         """
-        original_dtype = videos[0].dtype
+        original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
     def decode(self, zs, tile_size, any_end_frame = False):
+        original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
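Note: both Wan VAE wrappers switch from tracking the inputs' dtype to tracking their device, and copy the latent scale statistics onto that device before encode/decode. This presumably pairs with the device="cpu" change above, where the VAE may now be constructed on a different device than the tensors it is later asked to process. A minimal, self-contained sketch of the pattern with dummy tensors:

import torch

# Stand-in for self.scale: per-channel latent statistics kept wherever the VAE was built.
scale_cpu = [torch.zeros(16), torch.ones(16)]    # shapes are illustrative only
latents = torch.randn(1, 16, 4, 32, 32)          # dummy [B, C, T, H, W] latent
if torch.cuda.is_available():
    latents = latents.cuda()

# Same move as the hunk: align the statistics with whatever device the inputs arrive on.
original_device = latents.device
scale = [u.to(device=original_device) for u in scale_cpu]
assert all(s.device == latents.device for s in scale)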


@@ -1173,19 +1173,23 @@ class Wan2_2_VAE:
         """
        videos: A list of videos each with shape [C, T, H, W].
         """
-        original_dtype = videos[0].dtype
+        original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
     def decode(self, zs, tile_size, any_end_frame = False):
+        original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
    # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):

wgp.py

@@ -1662,8 +1662,8 @@ attention_modes_installed = get_attention_modes()
 attention_modes_supported = get_supported_attention_modes()
 args = _parse_args()
-major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-if major < 8:
+gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if gpu_major < 8:
     print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels")
     bfloat16_supported = False
 else:
@@ -3971,7 +3971,13 @@ def generate_video(
             wan_model, offloadobj = load_models(model_type)
             send_cmd("status", "Model loaded")
             reload_needed= False
+        override_attention = model_def.get("attention", None)
+        if override_attention is not None:
+            if isinstance(override_attention, dict):
+                override_attention = override_attention.get(gpu_major, None)
+            if override_attention is not None and override_attention not in attention_modes_supported: override_attention = None
+            if override_attention != attention_mode: print(f"Attention mode has been overriden to {override_attention} for model type '{model_type}'")
+        attn = override_attention if override_attention is not None else attention_mode
     if attention_mode == "auto":
         attn = get_auto_attention()
     elif attention_mode in attention_modes_supported:
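Note: a standalone sketch of the attention-override resolution added above, with made-up inputs. The real code reads model_def, attention_modes_supported and gpu_major from wgp.py module state; in the dict form the override is keyed by the GPU's CUDA compute-capability major.

def resolve_attention(model_def, attention_mode, gpu_major, supported):
    override = model_def.get("attention", None)
    if override is not None:
        if isinstance(override, dict):
            # Per-architecture override, keyed by compute-capability major.
            override = override.get(gpu_major, None)
        if override is not None and override not in supported:
            override = None  # fall back when the requested mode is not installed
    return override if override is not None else attention_mode

# Hypothetical values: force "sdpa" for this model type on compute-capability 7.x GPUs.
print(resolve_attention({"attention": {7: "sdpa"}}, "auto", 7, ["sdpa", "sage2"]))  # -> sdpa
print(resolve_attention({}, "auto", 9, ["sdpa", "sage2"]))                          # -> auto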