a few more fixes

DeepBeepMeep 2025-06-22 02:17:05 +02:00
parent 743f6911a1
commit 95727e618f
4 changed files with 79 additions and 53 deletions

View File

@@ -9,8 +9,7 @@
 		"URLs": [
 			"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
 			"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
-			"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
-			"ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
+			"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
 		],
 		"auto_quantize": true
 	},

View File

@@ -11,8 +11,6 @@ from typing import Union,Optional
 from mmgp import offload
 from .attention import pay_attention
 from torch.backends.cuda import sdp_kernel
-# from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
-# from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
 
 __all__ = ['WanModel']
@@ -156,7 +154,8 @@ class WanSelfAttention(nn.Module):
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6):
+                 eps=1e-6,
+                 block_no=0):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -165,6 +164,7 @@ class WanSelfAttention(nn.Module):
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        self.block_no = block_no
 
         # layers
         self.q = nn.Linear(dim, dim)
@@ -174,10 +174,6 @@ class WanSelfAttention(nn.Module):
         self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 
-        # Only initialize SparseDiffAttn if this is not a subclass initialization
-        # if self.__class__ == WanSelfAttention:
-        #     layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
-        #     self.attn = SparseDiffAttn(layer_num, layer_counter)
 
     def forward(self, xlist, grid_sizes, freqs, block_mask = None):
         r"""
@@ -204,9 +200,14 @@ class WanSelfAttention(nn.Module):
         del q,k
         q,k = apply_rotary_emb(qklist, freqs, head_first=False)
 
-        chipmunk = offload.shared_state["_chipmunk"]
-        if chipmunk:
-            x = self.attn(q, k, v)
+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk and self.__class__ == WanSelfAttention:
+            q = q.transpose(1,2)
+            k = k.transpose(1,2)
+            v = v.transpose(1,2)
+            attn_layers = offload.shared_state["_chipmunk_layers"]
+            x = attn_layers[self.block_no](q, k, v)
+            x = x.transpose(1,2)
         elif block_mask == None:
             qkv_list = [q,k,v]
             del q,k,v
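
Note: the hunk above replaces the per-module `self.attn` kernel with a lookup into `offload.shared_state["_chipmunk_layers"]` indexed by the new `block_no`. Below is a minimal, torch-only sketch of that registry pattern; the shared dict and module are stand-ins for illustration only, and plain scaled-dot-product attention takes the place of chipmunk's SparseDiffAttn.

import torch
import torch.nn.functional as F

shared_state = {}  # stands in for offload.shared_state

class ToyAttention(torch.nn.Module):
    # Each block stores only its integer index and resolves its attention
    # kernel from the shared registry at forward time.
    def __init__(self, block_no: int):
        super().__init__()
        self.block_no = block_no

    def forward(self, q, k, v):
        # inputs are (batch, seq, heads, head_dim); the kernel expects heads
        # before seq, hence the transposes mirroring the hunk above
        kernels = shared_state.get("_chipmunk_layers")
        kernel = kernels[self.block_no] if kernels is not None else F.scaled_dot_product_attention
        return kernel(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)

# register one "kernel" per transformer block, then run one block
shared_state["_chipmunk_layers"] = [F.scaled_dot_product_attention for _ in range(2)]
q = k = v = torch.randn(1, 16, 8, 64)
print(ToyAttention(block_no=1)(q, k, v).shape)  # torch.Size([1, 16, 8, 64])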
@@ -372,7 +373,8 @@ class WanAttentionBlock(nn.Module):
                  qk_norm=True,
                  cross_attn_norm=False,
                  eps=1e-6,
-                 block_id=None
+                 block_id=None,
+                 block_no = 0
                  ):
         super().__init__()
         self.dim = dim
@@ -382,11 +384,12 @@ class WanAttentionBlock(nn.Module):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.block_no = block_no
 
         # layers
         self.norm1 = WanLayerNorm(dim, eps)
         self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
-                                          eps)
+                                          eps, block_no= block_no)
         self.norm3 = WanLayerNorm(
             dim, eps,
             elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -394,7 +397,8 @@ class WanAttentionBlock(nn.Module):
             num_heads,
             (-1, -1),
             qk_norm,
-            eps)
+            eps,
+            block_no)
         self.norm2 = WanLayerNorm(dim, eps)
         self.ffn = nn.Sequential(
             nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
@@ -599,6 +603,27 @@ class MLPProj(torch.nn.Module):
 
 class WanModel(ModelMixin, ConfigMixin):
 
+    def setup_chipmunk(self):
+        from chipmunk.util import LayerCounter
+        from chipmunk.modules import SparseDiffMlp, SparseDiffAttn
+        seq_shape = (21, 45, 80)
+        chipmunk_layers =[]
+        for i in range(self.num_layers):
+            layer_num, layer_counter = LayerCounter.build_for_layer(is_attn_sparse=True, is_mlp_sparse=False)
+            chipmunk_layers.append( SparseDiffAttn(layer_num, layer_counter))
+        offload.shared_state["_chipmunk_layers"] = chipmunk_layers
+
+        chipmunk_layers[0].initialize_static_mask(
+            seq_shape=seq_shape,
+            txt_len=0,
+            local_heads_num=self.num_heads,
+            device='cuda'
+        )
+        chipmunk_layers[0].layer_counter.reset()
+
+    def release_chipmunk(self):
+        offload.shared_state["_chipmunk_layers"] = None
+
     def preprocess_loras(self, model_type, sd):
         first = next(iter(sd), None)
@@ -766,8 +791,8 @@ class WanModel(ModelMixin, ConfigMixin):
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps)
-            for _ in range(num_layers)
+                              window_size, qk_norm, cross_attn_norm, eps, block_no =i)
+            for i in range(num_layers)
         ])
 
         # head
@@ -791,7 +816,7 @@ class WanModel(ModelMixin, ConfigMixin):
         # blocks
         self.blocks = nn.ModuleList([
             WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
-                              self.cross_attn_norm, self.eps,
+                              self.cross_attn_norm, self.eps, block_no =i,
                               block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
             for i in range(self.num_layers)
         ])
@@ -944,6 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
         if torch.is_tensor(freqs) and freqs.device != device:
             freqs = freqs.to(device)
 
+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk:
+            from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
+            voxel_shape = (4, 6, 8)
 
         x_list = x
         joint_pass = len(x_list) > 1
@@ -960,20 +989,15 @@ class WanModel(ModelMixin, ConfigMixin):
             # embeddings
             x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
             grid_sizes = x.shape[2:]
-            x = x.flatten(2).transpose(1, 2)
+            if chipmunk:
+                x = x.unsqueeze(-1)
+                x_og_shape = x.shape
+                x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
+            else:
+                x = x.flatten(2).transpose(1, 2)
             x_list[i] = x
         x, y = None, None
 
-        offload.shared_state["_chipmunk"] = False
-        chipmunk = offload.shared_state["_chipmunk"]
-        if chipmunk:
-            voxel_shape = (4, 6, 8)
-            for x in x_list:
-                from src.chipmunk.ops.voxel import voxel_chunk_no_padding, reverse_voxel_chunk_no_padding
-                x = x.unsqueeze(-1)
-                x_og_shape = x.shape
-                x = voxel_chunk_no_padding(x, voxel_shape).squeeze(-1).transpose(1, 2)
-            x = None
-
         block_mask = None
         if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
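
For reference, the chipmunk branch above reorders the flattened (F, H, W) token grid before attention; the call names suggest tokens are grouped so each local (4, 6, 8) voxel becomes contiguous in the sequence. Below is a simplified, torch-only illustration of that idea under the assumption that the grid divides evenly by the voxel shape; it is a stand-in, not chipmunk's voxel_chunk_no_padding, which also handles non-divisible grids such as (21, 45, 80).

import torch

def voxel_chunk_toy(x, grid, voxel):
    # x: (batch, F*H*W, channels) flattened in F-major, then H, then W order
    b, s, c = x.shape
    f, h, w = grid
    vf, vh, vw = voxel
    x = x.view(b, f // vf, vf, h // vh, vh, w // vw, vw, c)
    # voxel indices first, positions inside each voxel next
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7)
    return x.reshape(b, s, c)

grid = (20, 48, 80)  # divisible by (4, 6, 8), unlike the real (21, 45, 80) case
x = torch.randn(1, grid[0] * grid[1] * grid[2], 16)
print(voxel_chunk_toy(x, grid, (4, 6, 8)).shape)  # torch.Size([1, 76800, 16])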

View File

@@ -108,7 +108,7 @@ class WanT2V:
         self.sample_neg_prompt = config.sample_neg_prompt
 
-        if "Vace" in model_filename[-1]:
+        if base_model_type in ["vace_14B", "vace_1.3B"]:
             self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
                                                min_area=480*832,
                                                max_area=480*832,
@@ -492,16 +492,10 @@ class WanT2V:
         if callback != None:
             callback(-1, None, True)
 
-        # seq_shape = (21, 45, 80)
-        # local_heads_num = 40 #12 for 1.3B
-
-        # self.model.blocks[0].self_attn.attn.initialize_static_mask(
-        #     seq_shape=seq_shape,
-        #     txt_len=0,
-        #     local_heads_num=local_heads_num,
-        #     device='cuda'
-        # )
-        # self.model.blocks[0].self_attn.attn.layer_counter.reset()
+        offload.shared_state["_chipmunk"] = False
+        chipmunk = offload.shared_state.get("_chipmunk", False)
+        if chipmunk:
+            self.model.setup_chipmunk()
 
         for i, t in enumerate(tqdm(timesteps)):
@@ -614,6 +608,9 @@ class WanT2V:
         x0 = [latents]
 
+        if chipmunk:
+            self.model.release_chipmunk() # need to add it at every exit when in prof
+
         if return_latent_slice != None:
             if overlapped_latents != None:
                 # latents [:, 1:] = self.toto
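
The inline comment above notes that release_chipmunk() must run on every exit path. One way to guarantee that is a try/finally around the denoising loop; the sketch below uses a dummy model and loop as stand-ins, not code from this repository.

class FakeModel:
    def setup_chipmunk(self):   print("chipmunk layers allocated")
    def release_chipmunk(self): print("chipmunk layers released")

model, chipmunk = FakeModel(), True
if chipmunk:
    model.setup_chipmunk()
try:
    for step in range(3):       # stands in for `for i, t in enumerate(tqdm(timesteps))`
        pass                    # denoising step would go here
finally:
    if chipmunk:
        model.release_chipmunk()  # runs on normal completion, early return, or exception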
@@ -640,6 +637,4 @@ class WanT2V:
             module = modules_dict[f"vace_blocks.{vace_layer}"]
             target = modules_dict[f"blocks.{model_layer}"]
             setattr(target, "vace", module )
 
         delattr(model, "vace_blocks")

wgp.py
View File

@@ -1772,7 +1772,7 @@ def fix_settings(model_type, ui_defaults):
         if not "I" in video_prompt_type: # workaround for settings corruption
             video_prompt_type += "I"
     if model_type in ["hunyuan"]:
-        del_in_sequence(video_prompt_type, "I")
+        video_prompt_type = video_prompt_type.replace("I", "")
     ui_defaults["video_prompt_type"] = video_prompt_type
 
 def get_default_settings(model_type):
@@ -1990,7 +1990,9 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file):
     pos = model_filename.rfind(".")
     model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:]
 
-    if not os.path.isfile(model_filename):
+    if os.path.isfile(model_filename):
+        print(f"There is no model to quantize: quantized model '{model_filename}' already exists")
+    else:
         offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file)
         print(f"New quantized file '{model_filename}' has been created for finetune Id '{model_type}'.")
         finetune_def = get_model_finetune_def(model_type)
@@ -2360,10 +2362,18 @@ def load_models(model_type):
     preload =int(args.preload)
     save_quantized = args.save_quantized and finetune_def != None
     model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
+    modules = finetune_def.get("modules", []) if finetune_def != None else []
     if save_quantized and "quanto" in model_filename:
         save_quantized = False
         print("Need to provide a non quantized model to create a quantized model to be saved")
-    quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+    if save_quantized and len(modules) > 0:
+        _, model_types_no_module = dependent_models_types = get_dependent_models(base_model_type, transformer_quantization, transformer_dtype_policy)
+        print(f"Unable to create a quantized finetune model because some modules are declared in the finetune definition. If your finetune already includes the module weights, you can remove the 'modules' entry and try again. Otherwise you will also need to temporarily change the model 'architecture' to one that does not require the modules part ('{model_types_no_module[0] if len(model_types_no_module)>0 else ''}' ?), quantize, then restore the original 'modules' and 'architecture' entries.")
+        save_quantized = False
+    quantizeTransformer = not save_quantized and finetune_def !=None and transformer_quantization in ("int8", "fp8") and finetune_def.get("auto_quantize", False) and not "quanto" in model_filename
+    if quantizeTransformer and len(modules) > 0:
+        print("Autoquantize is not yet supported if some modules are declared")
+        quantizeTransformer = False
     model_family = get_model_family(model_type)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
     if quantizeTransformer or "quanto" in model_filename:
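
Condensed, the guards added in this hunk amount to the decision function below: a finetune that declares extra 'modules' can neither be saved as a quantized model nor auto-quantized on load. This is a standalone restatement with a hypothetical name, not a helper that exists in wgp.py.

def decide_quantization(save_requested, finetune_def, quantization, filename, modules):
    # mirrors the control flow added above, minus the printed explanations
    save_quantized = save_requested and finetune_def is not None
    if save_quantized and "quanto" in filename:
        save_quantized = False      # source model is already quantized
    if save_quantized and len(modules) > 0:
        save_quantized = False      # module weights would be missing from the saved file
    auto_quantize = (not save_quantized and finetune_def is not None
                     and quantization in ("int8", "fp8")
                     and finetune_def.get("auto_quantize", False)
                     and "quanto" not in filename
                     and len(modules) == 0)
    return save_quantized, auto_quantize

print(decide_quantization(True, {"auto_quantize": True}, "int8",
                          "Wan14BT2VFusioniX_fp16.safetensors", ["some_module"]))  # (False, False)
print(decide_quantization(False, {"auto_quantize": True}, "int8",
                          "Wan14BT2VFusioniX_fp16.safetensors", []))               # (False, True)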
@@ -2379,16 +2389,14 @@ def load_models(model_type):
     model_file_list = dependent_models + [model_filename]
     model_type_list = dependent_models_types + [model_type]
     new_transformer_filename = model_file_list[-1]
-    if finetune_def != None:
-        for module_type in finetune_def.get("modules", []):
-            model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
-            model_type_list.append(module_type)
+    for module_type in modules:
+        model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype))
+        model_type_list.append(module_type)
     for filename, file_model_type in zip(model_file_list, model_type_list):
         download_models(filename, file_model_type)
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
-    transformer_filename = None
     transformer_loras_filenames = None
     transformer_type = None
     for i, filename in enumerate(model_file_list):
@@ -5317,9 +5325,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None
                             video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False)
                             video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False)
-                video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
-                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
+                video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
+                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
                 image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images",
                             type ="pil", show_label= True,