diff --git a/models/qwen/autoencoder_kl_qwenimage.py b/models/qwen/autoencoder_kl_qwenimage.py
index ce5b1df..d144284 100644
--- a/models/qwen/autoencoder_kl_qwenimage.py
+++ b/models/qwen/autoencoder_kl_qwenimage.py
@@ -674,6 +674,32 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
     _supports_gradient_checkpointing = False
 
+    @staticmethod
+    def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
+
+        # VAE Tiling
+        if vae_config == 0:
+            if device_mem_capacity >= 24000:
+                use_vae_config = 1
+            elif device_mem_capacity >= 8000:
+                use_vae_config = 2
+            else:
+                use_vae_config = 3
+        else:
+            use_vae_config = vae_config
+
+        use_tiling = False
+        tile_sample_min_width = 256
+
+        if use_vae_config == 1:
+            use_tiling = False
+        elif use_vae_config == 2:
+            use_tiling = True
+            tile_sample_min_width = 256
+
+        return (use_tiling, tile_sample_min_width)
+
+
     # fmt: off
     @register_to_config
     def __init__(
diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py
index 246777e..fd41e94 100644
--- a/models/qwen/qwen_main.py
+++ b/models/qwen/qwen_main.py
@@ -83,6 +83,7 @@ class model_factory():
         loras_slists = None,
         batch_size = 1,
         video_prompt_type = "",
+        VAE_tile_size = None,
         **bbargs
     ):
         # Generate with different aspect ratios
@@ -94,6 +95,11 @@ class model_factory():
             "3:4": (1140, 1472)
         }
 
+        if VAE_tile_size is not None:
+            self.vae.use_tiling = VAE_tile_size[0]
+            self.vae.tile_latent_min_height = VAE_tile_size[1]
+            self.vae.tile_latent_min_width = VAE_tile_size[1]
+
         # width, height = aspect_ratios["16:9"]
 
         image = self.pipeline(
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 83038bf..a91cb73 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -112,7 +112,7 @@ class WanAny2V:
         self.vae = vae(
             vae_pth=os.path.join(checkpoint_dir, vae_checkpoint),
             dtype= VAE_dtype,
-            device=self.device)
+            device="cpu")
 
         # config_filename= "configs/t2v_1.3B.json"
         # import json
diff --git a/models/wan/modules/vae.py b/models/wan/modules/vae.py
index ed6a4ab..0d1ab7a 100644
--- a/models/wan/modules/vae.py
+++ b/models/wan/modules/vae.py
@@ -832,16 +832,18 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
""" - original_dtype = videos[0].dtype - + original_device = videos[0].device + scale = [u.to(device = original_device) for u in self.scale] if tile_size > 0: - return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: - return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] def decode(self, zs, tile_size, any_end_frame = False): + original_device = zs[0].device + scale = [u.to(device = original_device) for u in self.scale] if tile_size > 0: - return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: - return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] diff --git a/models/wan/modules/vae2_2.py b/models/wan/modules/vae2_2.py index 3e67f54..c4d1e46 100644 --- a/models/wan/modules/vae2_2.py +++ b/models/wan/modules/vae2_2.py @@ -1173,19 +1173,23 @@ class Wan2_2_VAE: """ videos: A list of videos each with shape [C, T, H, W]. 
""" - original_dtype = videos[0].dtype + original_device = videos[0].device + scale = [u.to(device = original_device) for u in self.scale] if tile_size > 0 and False: - return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] else: - return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] + return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ] def decode(self, zs, tile_size, any_end_frame = False): + original_device = zs[0].device + scale = [u.to(device = original_device) for u in self.scale] + if tile_size > 0 and False: - return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] else: - return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] + return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ] # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ): diff --git a/wgp.py b/wgp.py index 14f4c07..c932141 100644 --- a/wgp.py +++ b/wgp.py @@ -1662,8 +1662,8 @@ attention_modes_installed = get_attention_modes() attention_modes_supported = get_supported_attention_modes() args = _parse_args() -major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) -if major < 8: +gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None) +if gpu_major < 8: print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels") bfloat16_supported = False else: @@ -3971,7 +3971,13 @@ def generate_video( wan_model, offloadobj = load_models(model_type) send_cmd("status", "Model loaded") reload_needed= False - + override_attention = model_def.get("attention", None) + if override_attention is not None: + if isinstance(override_attention, dict): + override_attention = override_attention.get(gpu_major, None) + if override_attention is not None and override_attention not in attention_modes_supported: override_attention = None + if override_attention != attention_mode: print(f"Attention mode has been overriden to {override_attention} for model type '{model_type}'") + attn = override_attention if override_attention is not None else attention_mode if attention_mode == "auto": attn = get_auto_attention() elif attention_mode in attention_modes_supported: