Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-16 12:13:27 +00:00)

added qwen vae tiling

Parent: 3c185ae69a
Commit: 66ddadf0cc
@@ -674,6 +674,32 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = False

+    @staticmethod
+    def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
+        # VAE Tiling
+        if vae_config == 0:
+            if device_mem_capacity >= 24000:
+                use_vae_config = 1
+            elif device_mem_capacity >= 8000:
+                use_vae_config = 2
+            else:
+                use_vae_config = 3
+        else:
+            use_vae_config = vae_config
+
+        use_tiling = False
+        tile_sample_min_width = 256
+
+        if use_vae_config == 1:
+            use_tiling = False
+        elif use_vae_config == 2:
+            use_tiling = True
+            tile_sample_min_width = 256
+
+        return (use_tiling, tile_sample_min_width)
+
     # fmt: off
     @register_to_config
     def __init__(
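The 24000/8000 thresholds above suggest device_mem_capacity is expressed in MB. A minimal sketch of the same selection in isolation, under that assumption; pick_qwen_vae_tiling is a hypothetical helper name used only for this illustration, not part of the commit:

import torch

def pick_qwen_vae_tiling(vae_config=0):
    # Illustrative re-implementation of get_VAE_tile_size, assuming VRAM is measured in MB.
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)
    if vae_config == 0:  # 0 = "auto": derive the profile from available VRAM
        if device_mem_capacity >= 24000:
            use_vae_config = 1      # >= 24 GB: no tiling needed
        elif device_mem_capacity >= 8000:
            use_vae_config = 2      # 8-24 GB: tile the VAE pass
        else:
            use_vae_config = 3      # < 8 GB
    else:
        use_vae_config = vae_config  # explicit user override

    use_tiling = use_vae_config == 2   # only profile 2 enables tiling in the patch above
    tile_sample_min_width = 256
    return use_tiling, tile_sample_min_width

if torch.cuda.is_available():
    print(pick_qwen_vae_tiling())      # e.g. (False, 256) on a 24 GB GPU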
@@ -83,6 +83,7 @@ class model_factory():
         loras_slists = None,
         batch_size = 1,
         video_prompt_type = "",
+        VAE_tile_size = None,
         **bbargs
     ):
         # Generate with different aspect ratios
@@ -94,6 +95,11 @@ class model_factory():
             "3:4": (1140, 1472)
         }

+        if VAE_tile_size is not None:
+            self.vae.use_tiling = VAE_tile_size[0]
+            self.vae.tile_latent_min_height = VAE_tile_size[1]
+            self.vae.tile_latent_min_width = VAE_tile_size[1]
+
         # width, height = aspect_ratios["16:9"]

         image = self.pipeline(
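A self-contained sketch of what the new branch does with the (use_tiling, tile_size) tuple, assuming the VAE object exposes the same use_tiling / tile_latent_min_* attributes the hunk assigns; apply_vae_tile_size and _DummyVAE are illustrative names, not code from the repository:

def apply_vae_tile_size(vae, VAE_tile_size):
    # Copy a (use_tiling, tile_size) tuple onto a VAE, mirroring the hunk above.
    if VAE_tile_size is None:
        return
    use_tiling, tile_size = VAE_tile_size
    vae.use_tiling = use_tiling
    # The same value is used for both latent tile dimensions, as in the patch.
    vae.tile_latent_min_height = tile_size
    vae.tile_latent_min_width = tile_size

class _DummyVAE:   # stand-in object so the sketch is self-contained
    pass

vae = _DummyVAE()
apply_vae_tile_size(vae, (True, 256))   # e.g. the 8-24 GB profile
assert vae.use_tiling and vae.tile_latent_min_width == 256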
@@ -112,7 +112,7 @@ class WanAny2V:

         self.vae = vae(
             vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
-            device=self.device)
+            device="cpu")

         # config_filename= "configs/t2v_1.3B.json"
         # import json
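The VAE is now instantiated on the CPU rather than on the compute device. The hunk does not show where it is later moved, so the following is only an assumed offloading pattern (keep the VAE resident in RAM, borrow the GPU for a call, then release the VRAM), not the repository's mechanism:

import torch

def decode_with_offload(vae_model, latents, device="cuda"):
    # Assumed pattern: move the CPU-resident VAE to the GPU only for the duration of the call.
    vae_model.to(device)
    try:
        with torch.no_grad():
            out = vae_model(latents.to(device))   # stand-in for the real decode call
    finally:
        vae_model.to("cpu")                       # free VRAM for the diffusion transformer
        torch.cuda.empty_cache()
    return out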
@@ -832,16 +832,18 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
         original_dtype = videos[0].dtype
         original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


     def decode(self, zs, tile_size, any_end_frame = False):
         original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
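The change reads as a device-placement fix: self.scale (the latent mean/std) is converted to the input tensor's device on every call instead of being passed as-is. A minimal, self-contained illustration of the mismatch this avoids; the 16-channel shapes are made up for the example:

import torch

mean = torch.zeros(16)             # e.g. self.scale[0], loaded on CPU at init time
std = torch.ones(16)               # e.g. self.scale[1]
latent = torch.randn(1, 16, 4, 32, 32)

if torch.cuda.is_available():
    latent = latent.cuda()
    # mean/std must be moved to latent.device first, as the patch now does per call;
    # otherwise the element-wise ops below raise a cross-device RuntimeError.
    mean, std = mean.to(latent.device), std.to(latent.device)

normalized = (latent - mean.view(1, -1, 1, 1, 1)) / std.view(1, -1, 1, 1, 1)
print(normalized.shape)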
@@ -1173,19 +1173,23 @@ class Wan2_2_VAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
         original_dtype = videos[0].dtype
+        original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]

         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


     def decode(self, zs, tile_size, any_end_frame = False):
+        original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]

         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]

     # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):
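Note the "and False" guard: for Wan2_2_VAE the spatially tiled path stays disabled, so only the per-call scale conversion takes effect here. For context, a generic, simplified sketch of what spatial tiling does (split the input into overlapping tiles, process each, average the overlaps so peak memory stays bounded); this is an illustration, not this repository's spatial_tiled_decode:

import torch

def tiled_apply(fn, x, tile=64, overlap=16):
    # Apply fn to overlapping spatial tiles of x [B, C, H, W] and average the overlaps.
    # Simplification: fn must preserve spatial size; a real VAE decoder upsamples,
    # which needs extra bookkeeping (and feathered blend masks) omitted here.
    B, C, H, W = x.shape
    out = torch.zeros_like(x)
    weight = torch.zeros(1, 1, H, W, device=x.device, dtype=x.dtype)
    stride = tile - overlap
    for top in range(0, H, stride):
        for left in range(0, W, stride):
            bottom, right = min(top + tile, H), min(left + tile, W)
            out[:, :, top:bottom, left:right] += fn(x[:, :, top:bottom, left:right])
            weight[:, :, top:bottom, left:right] += 1.0
    return out / weight

x = torch.randn(1, 4, 128, 128)
y = tiled_apply(lambda t: t * 2.0, x)   # trivial stand-in "decoder"
assert torch.allclose(y, x * 2.0, atol=1e-5)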
wgp.py (12 changed lines)
@@ -1662,8 +1662,8 @@ attention_modes_installed = get_attention_modes()
 attention_modes_supported = get_supported_attention_modes()
 args = _parse_args()

-major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-if major < 8:
+gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if gpu_major < 8:
     print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels")
     bfloat16_supported = False
 else:
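The rename keeps the compute-capability major version available as gpu_major for reuse later in wgp.py (see the next hunk). The check itself, reduced to a standalone sketch; pick_model_dtype is an illustrative helper, not a wgp.py function:

import torch

def pick_model_dtype(device=None):
    # Ampere (compute capability 8.x) and newer expose fast BF16 kernels;
    # older architectures fall back to FP16.
    if not torch.cuda.is_available():
        return torch.float32
    gpu_major, _gpu_minor = torch.cuda.get_device_capability(device)
    return torch.bfloat16 if gpu_major >= 8 else torch.float16

print(pick_model_dtype())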
@@ -3971,7 +3971,13 @@ def generate_video(
     wan_model, offloadobj = load_models(model_type)
     send_cmd("status", "Model loaded")
     reload_needed= False

+    override_attention = model_def.get("attention", None)
+    if override_attention is not None:
+        if isinstance(override_attention, dict):
+            override_attention = override_attention.get(gpu_major, None)
+        if override_attention is not None and override_attention not in attention_modes_supported: override_attention = None
+        if override_attention != attention_mode: print(f"Attention mode has been overriden to {override_attention} for model type '{model_type}'")
+    attn = override_attention if override_attention is not None else attention_mode
     if attention_mode == "auto":
         attn = get_auto_attention()
     elif attention_mode in attention_modes_supported:
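The dict form of the override is keyed by gpu_major (from the renamed variable in the previous hunk), so a model definition can request a different attention backend per GPU architecture, falling back to the user's attention_mode if the override is missing or unsupported. The same logic, reduced to a standalone function for readability; the "sage2"/"sdpa" strings and the 8/9 keys in the example are hypothetical values, not taken from the repository:

def resolve_attention(model_def, gpu_major, attention_mode, supported):
    override = model_def.get("attention", None)
    if isinstance(override, dict):
        # Per-architecture override: keyed by CUDA compute-capability major version.
        override = override.get(gpu_major, None)
    if override is not None and override not in supported:
        override = None
    return override if override is not None else attention_mode

supported = ["sdpa", "sage2"]
model_def = {"attention": {9: "sage2", 8: "sdpa"}}   # hypothetical per-GPU mapping
print(resolve_attention(model_def, gpu_major=9, attention_mode="sdpa", supported=supported))  # sage2
print(resolve_attention(model_def, gpu_major=7, attention_mode="sdpa", supported=supported))  # sdpa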