Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-12-16 12:13:27 +00:00

Commit 66ddadf0cc (parent 3c185ae69a): added qwen vae tiling
@@ -674,6 +674,32 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):

     _supports_gradient_checkpointing = False

+    @staticmethod
+    def get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision):
+        # VAE Tiling
+        if vae_config == 0:
+            if device_mem_capacity >= 24000:
+                use_vae_config = 1
+            elif device_mem_capacity >= 8000:
+                use_vae_config = 2
+            else:
+                use_vae_config = 3
+        else:
+            use_vae_config = vae_config
+
+        use_tiling = False
+        tile_sample_min_width = 256
+
+        if use_vae_config == 1:
+            use_tiling = False
+        elif use_vae_config == 2:
+            use_tiling = True
+            tile_sample_min_width = 256
+
+        return (use_tiling, tile_sample_min_width)
+
     # fmt: off
     @register_to_config
     def __init__(
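The new static helper maps available VRAM (in MB) to a tiling profile: 24 GB or more runs untiled, 8 GB or more enables spatial tiling with 256-pixel tiles, and the mixed_precision argument is accepted but not consulted by the heuristic. A minimal usage sketch, assuming the patched class is importable from this fork and a CUDA device is present (the memory query is standard torch, not part of this commit):

    import torch
    # from <patched module> import AutoencoderKLQwenImage   # import path depends on this fork

    # Report GPU memory in MB, matching the 24000/8000 thresholds used by the helper.
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)

    # vae_config == 0 means "choose automatically from VRAM"; 1/2/3 force a profile.
    use_tiling, tile_sample_min_width = AutoencoderKLQwenImage.get_VAE_tile_size(
        vae_config=0,
        device_mem_capacity=device_mem_capacity,
        mixed_precision=False,   # accepted but unused by the branching above
    )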
@@ -83,6 +83,7 @@ class model_factory():
         loras_slists = None,
         batch_size = 1,
         video_prompt_type = "",
+        VAE_tile_size = None,
         **bbargs
     ):
         # Generate with different aspect ratios
@@ -94,6 +95,11 @@ class model_factory():
             "3:4": (1140, 1472)
         }

+        if VAE_tile_size is not None:
+            self.vae.use_tiling = VAE_tile_size[0]
+            self.vae.tile_latent_min_height = VAE_tile_size[1]
+            self.vae.tile_latent_min_width = VAE_tile_size[1]
+
         # width, height = aspect_ratios["16:9"]

         image = self.pipeline(
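Together with the helper above, the new VAE_tile_size argument lets the caller pick a tiling profile once and hand the (use_tiling, tile_size) tuple to generation, where it toggles use_tiling and shrinks tile_latent_min_height/width on the pipeline's VAE. A sketch of that flow; the method name generate and the handler object are assumptions, only the helper call and the new keyword come from this commit:

    def generate_with_auto_tiling(handler, prompt, device_mem_capacity_mb):
        # Hypothetical wrapper: pick the tiling profile from available VRAM (MB),
        # then pass it through the new VAE_tile_size keyword. `handler` is assumed
        # to be a model_factory instance and `generate` is an assumed method name.
        tile_settings = AutoencoderKLQwenImage.get_VAE_tile_size(
            0, device_mem_capacity_mb, False)
        return handler.generate(prompt=prompt, VAE_tile_size=tile_settings)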
@@ -112,7 +112,7 @@ class WanAny2V:

         self.vae = vae(
             vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype,
-            device=self.device)
+            device="cpu")

         # config_filename= "configs/t2v_1.3B.json"
         # import json
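Constructing the VAE on "cpu" keeps its weights out of VRAM while the diffusion transformer runs, so GPU memory is only needed around the actual encode/decode call. A sketch of that pattern, assuming the WanVAE wrapper shown in the next hunks; the move-to-GPU and move-back choreography is an assumption, not something this commit adds:

    import torch

    def decode_with_offload(vae, latents, tile_size=256, device="cuda"):
        # Move the CPU-resident VAE to the GPU only for decoding, then evict it
        # again so the transformer keeps its VRAM headroom.
        vae.model.to(device)
        try:
            with torch.no_grad():
                return vae.decode(latents, tile_size)
        finally:
            vae.model.to("cpu")
            torch.cuda.empty_cache()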
@@ -832,16 +832,18 @@ class WanVAE:
         """
        videos: A list of videos each with shape [C, T, H, W].
         """
-        original_dtype = videos[0].dtype
+        original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


     def decode(self, zs, tile_size, any_end_frame = False):
+        original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]
         if tile_size > 0:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
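The substantive change in both VAE wrappers is that the cached self.scale tensors are re-homed to whatever device the incoming videos or latents live on, instead of being passed as-is; a CPU-resident scale meeting GPU latents would otherwise trip a device-mismatch error. A toy illustration of the failure mode and the fix (shapes and the normalization formula are purely illustrative):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scale = [torch.zeros(16, 1, 1, 1), torch.ones(16, 1, 1, 1)]   # cached on CPU at init time
    z = torch.randn(16, 1, 32, 32, device=device)                 # latents may arrive on the GPU

    # When z is on a GPU, scale[0] + z * scale[1] with CPU scale tensors raises
    # "Expected all tensors to be on the same device"; the patch performs exactly
    # this re-homing before the model call:
    scale = [s.to(device=z.device) for s in scale]
    z_norm = scale[0] + z * scale[1]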
@@ -1173,19 +1173,23 @@ class Wan2_2_VAE:
         """
        videos: A list of videos each with shape [C, T, H, W].
         """
-        original_dtype = videos[0].dtype
+        original_device = videos[0].device
+        scale = [u.to(device = original_device) for u in self.scale]

         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


     def decode(self, zs, tile_size, any_end_frame = False):
+        original_device = zs[0].device
+        scale = [u.to(device = original_device) for u in self.scale]

         if tile_size > 0 and False:
-            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]


     # def encode(self, videos, VAE_tile_size = 0, any_end_frame = False ):
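Wan2_2_VAE receives the same scale fix, but its "if tile_size > 0 and False:" guards (already present before this commit) mean the spatially tiled paths are never taken, so tile_size is effectively ignored for this wrapper. A small round-trip sketch under that assumption; wrapper construction is elided and the signatures follow the hunk above:

    def roundtrip(vae, frames, tile_size=256):
        # Encode then decode one clip through the patched wrapper. `frames` is a
        # [C, T, H, W] tensor in [-1, 1]; because of the "and False" guard the
        # tile_size argument currently has no effect for Wan2_2_VAE.
        latents = vae.encode([frames], tile_size)      # list in, list of latents out
        return vae.decode(latents, tile_size)[0]       # decode clamps to [-1, 1]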
wgp.py (12 changed lines)

@@ -1662,8 +1662,8 @@ attention_modes_installed = get_attention_modes()
 attention_modes_supported = get_supported_attention_modes()
 args = _parse_args()

-major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-if major < 8:
+gpu_major, gpu_minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if gpu_major < 8:
     print("Switching to FP16 models when possible as GPU architecture doesn't support optimed BF16 Kernels")
     bfloat16_supported = False
 else:
@@ -3971,7 +3971,13 @@ def generate_video(
         wan_model, offloadobj = load_models(model_type)
         send_cmd("status", "Model loaded")
         reload_needed= False
+    override_attention = model_def.get("attention", None)
+    if override_attention is not None:
+        if isinstance(override_attention, dict):
+            override_attention = override_attention.get(gpu_major, None)
+        if override_attention is not None and override_attention not in attention_modes_supported: override_attention = None
+        if override_attention != attention_mode: print(f"Attention mode has been overriden to {override_attention} for model type '{model_type}'")
+    attn = override_attention if override_attention is not None else attention_mode
     if attention_mode == "auto":
         attn = get_auto_attention()
     elif attention_mode in attention_modes_supported:
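The second wgp.py hunk adds a per-model attention override read from the model definition: the "attention" entry may be a plain mode name or a dict keyed by the GPU's compute-capability major version (the renamed gpu_major from the first hunk), and it is dropped if the mode is not in attention_modes_supported. A self-contained sketch of that lookup; the mode names and model definitions are illustrative, not taken from the repository:

    # Illustrative model definitions; a plain string forces one mode everywhere,
    # a dict keyed by compute-capability major picks a mode per GPU generation.
    model_def_simple = {"attention": "sdpa"}
    model_def_per_gpu = {"attention": {8: "sage", 9: "sage2"}}

    gpu_major = 8
    attention_modes_supported = ["sdpa", "sage", "sage2"]
    attention_mode = "auto"

    override_attention = model_def_per_gpu.get("attention", None)
    if isinstance(override_attention, dict):
        override_attention = override_attention.get(gpu_major, None)   # pick per architecture
    if override_attention is not None and override_attention not in attention_modes_supported:
        override_attention = None                                      # unsupported: fall back
    attn = override_attention if override_attention is not None else attention_mode
    print(attn)   # -> "sage" on a compute-capability 8.x GPU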