Optimized Vace RAM usage

This commit is contained in:
DeepBeepMeep 2025-04-11 00:10:03 +02:00
parent f5bb9d5972
commit 8d998d645f
3 changed files with 110 additions and 78 deletions

View File

@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
hint = None
if self.block_id is not None and hints is not None:
kwargs = {
"seq_lens" : seq_lens,
"grid_sizes" : grid_sizes,
"freqs" :freqs,
"context" : context,
"context_lens" : context_lens,
"e" : e,
}
if self.block_id == 0:
hint = self.vace(hints, x, **kwargs)
else:
hint = self.vace(hints, None, **kwargs)
e = (self.modulation + e).chunk(6, dim=1)
# self-attention
@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[5])
if self.block_id is not None and hints != None:
if hint is not None:
if context_scale == 1:
x.add_(hints[self.block_id])
x.add_(hint)
else:
x.add_(hints[self.block_id], alpha =context_scale)
return x
x.add_(hint, alpha= context_scale)
return x
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
def forward(self, hints, x, **kwargs):
# behold dbm magic !
c = hints[0]
hints[0] = None
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = c
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
return all_c
hints[0] = c
return c_skip
# def forward(self, c, x, **kwargs):
# # behold dbm magic !
# if self.block_id == 0:
# c = self.before_proj(c) + x
# all_c = []
# else:
# all_c = c
# c = all_c.pop(-1)
# c = super().forward(c, **kwargs)
# c_skip = self.after_proj(c)
# all_c += [c_skip, c]
# return all_c
class Head(nn.Module):
@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
return best_threshold
def forward_vace(
self,
x,
vace_context,
seq_len,
context,
e,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
if (len(c) == 1 and seq_len == c[0].size(1)):
c = c[0]
else:
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for block in self.vace_blocks:
c = block(c, context= context, e= e, **new_kwargs)
hints = c[:-1]
# def forward_vace(
# self,
# x,
# vace_context,
# seq_len,
# context,
# e,
# kwargs
# ):
# # embeddings
# c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
# c = [u.flatten(2).transpose(1, 2) for u in c]
# if (len(c) == 1 and seq_len == c[0].size(1)):
# c = c[0]
# else:
# c = torch.cat([
# torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
# dim=1) for u in c
# ])
return hints
# # arguments
# new_kwargs = dict(x=x)
# new_kwargs.update(kwargs)
# for block in self.vace_blocks:
# c = block(c, context= context, e= e, **new_kwargs)
# hints = c[:-1]
# return hints
def forward(
self,
@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
x_list = [x]
context_list = [context]
del x
# arguments
kwargs = dict(
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context_lens=context_lens,
)
if vace_context == None:
hints_list = [None ] *len(x_list)
else:
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
if (len(c) == 1 and seq_len == c[0].size(1)):
c = c[0]
else:
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
kwargs['context_scale'] = vace_context_scale
hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
del c
should_calc = True
if self.enable_teacache:
if is_uncond:
@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
if joint_pass or not is_uncond:
self.previous_residual_cond = None
ori_hidden_states = x_list[0].clone()
# arguments
kwargs = dict(
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context_lens=context_lens)
if vace_context == None:
hints_list = [None ] *len(x_list)
else:
hints_list = []
for x, context in zip(x_list, context_list) :
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
del x, context
kwargs['context_scale'] = vace_context_scale
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx

View File

@ -143,6 +143,8 @@ class WanT2V:
seq_len=32760,
keep_last=True)
self.adapt_vace_model()
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
if ref_images is None:
ref_images = [None] * len(frames)
@ -505,3 +507,14 @@ class WanT2V:
dist.barrier()
return videos[0] if self.rank == 0 else None
def adapt_vace_model(self, *, vace_block_count=15, block_stride=2):
    """Attach each VACE block to its paired main block, then drop ``vace_blocks``.

    VACE block ``num`` (looked up as ``vace_blocks.{num}`` among the model's
    named modules) is stored as the ``vace`` attribute of the main block
    named ``blocks.{block_stride * num}``. Once every pair has been
    attached, the ``vace_blocks`` attribute is deleted from ``self.model``.

    Args:
        vace_block_count: Number of VACE blocks to attach. Defaults to 15.
        block_stride: Index distance between main blocks that receive a
            VACE block. Defaults to 2, i.e. every other main block.

    Raises:
        KeyError: If an expected ``vace_blocks.{num}`` or ``blocks.{...}``
            entry is missing. The model is left unmodified in that case.
    """
    model = self.model
    modules_by_name = dict(model.named_modules())

    # Resolve every (vace block, main block) pair before mutating anything,
    # so a missing module cannot leave the model half-adapted.
    pairs = [
        (
            modules_by_name[f"vace_blocks.{num}"],
            modules_by_name[f"blocks.{block_stride * num}"],
        )
        for num in range(vace_block_count)
    ]

    for vace_block, target in pairs:
        target.vace = vace_block

    del model.vace_blocks

23
wgp.py
View File

@ -910,14 +910,6 @@ def get_queue_table(queue):
if len(queue) == 1:
return data
# def td(l, content, width =None):
# if width !=None:
# l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
# else:
# l.append("<TD>" + content + "</TD>")
# data.append("<STYLE> .TB, .TB th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
for i, item in enumerate(queue):
if i==0:
continue
@ -937,22 +929,7 @@ def get_queue_table(queue):
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
if end_img_uri:
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
# if i % 2 == 1:
# data.append("<TR>")
# else:
# data.append("<TR BGCOLOR=#F2F2F2>")
# td(data,str(item.get('repeats', "1")) )
# td(data, prompt_cell, "100%")
# td(data, num_steps, "100%")
# td(data, start_img_md)
# td(data, end_img_md)
# td(data, "↑")
# td(data, "↓")
# td(data, "✖")
# data.append("</TR>")
# data.append("</TABLE>")
# return ''.join(data)
data.append([item.get('repeats', "1"),
prompt_cell,