mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Optimized Vace RAM usage
This commit is contained in:
parent
f5bb9d5972
commit
8d998d645f
@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
|
||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
hint = None
|
||||
if self.block_id is not None and hints is not None:
|
||||
kwargs = {
|
||||
"seq_lens" : seq_lens,
|
||||
"grid_sizes" : grid_sizes,
|
||||
"freqs" :freqs,
|
||||
"context" : context,
|
||||
"context_lens" : context_lens,
|
||||
"e" : e,
|
||||
}
|
||||
if self.block_id == 0:
|
||||
hint = self.vace(hints, x, **kwargs)
|
||||
else:
|
||||
hint = self.vace(hints, None, **kwargs)
|
||||
|
||||
e = (self.modulation + e).chunk(6, dim=1)
|
||||
|
||||
# self-attention
|
||||
@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
x.addcmul_(y, e[5])
|
||||
|
||||
|
||||
if self.block_id is not None and hints != None:
|
||||
|
||||
|
||||
if hint is not None:
|
||||
if context_scale == 1:
|
||||
x.add_(hints[self.block_id])
|
||||
x.add_(hint)
|
||||
else:
|
||||
x.add_(hints[self.block_id], alpha =context_scale)
|
||||
return x
|
||||
x.add_(hint, alpha= context_scale)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
def __init__(
|
||||
@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
nn.init.zeros_(self.after_proj.weight)
|
||||
nn.init.zeros_(self.after_proj.bias)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
def forward(self, hints, x, **kwargs):
|
||||
# behold dbm magic !
|
||||
c = hints[0]
|
||||
hints[0] = None
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = c
|
||||
c = all_c.pop(-1)
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
return all_c
|
||||
hints[0] = c
|
||||
return c_skip
|
||||
|
||||
# def forward(self, c, x, **kwargs):
|
||||
# # behold dbm magic !
|
||||
# if self.block_id == 0:
|
||||
# c = self.before_proj(c) + x
|
||||
# all_c = []
|
||||
# else:
|
||||
# all_c = c
|
||||
# c = all_c.pop(-1)
|
||||
# c = super().forward(c, **kwargs)
|
||||
# c_skip = self.after_proj(c)
|
||||
# all_c += [c_skip, c]
|
||||
# return all_c
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
||||
return best_threshold
|
||||
|
||||
def forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
context,
|
||||
e,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||
c = c[0]
|
||||
else:
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for block in self.vace_blocks:
|
||||
c = block(c, context= context, e= e, **new_kwargs)
|
||||
hints = c[:-1]
|
||||
# def forward_vace(
|
||||
# self,
|
||||
# x,
|
||||
# vace_context,
|
||||
# seq_len,
|
||||
# context,
|
||||
# e,
|
||||
# kwargs
|
||||
# ):
|
||||
# # embeddings
|
||||
# c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
# c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
# if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||
# c = c[0]
|
||||
# else:
|
||||
# c = torch.cat([
|
||||
# torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
# dim=1) for u in c
|
||||
# ])
|
||||
|
||||
return hints
|
||||
# # arguments
|
||||
# new_kwargs = dict(x=x)
|
||||
# new_kwargs.update(kwargs)
|
||||
|
||||
# for block in self.vace_blocks:
|
||||
# c = block(c, context= context, e= e, **new_kwargs)
|
||||
# hints = c[:-1]
|
||||
|
||||
# return hints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
x_list = [x]
|
||||
context_list = [context]
|
||||
del x
|
||||
|
||||
# arguments
|
||||
|
||||
kwargs = dict(
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=freqs,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
|
||||
if vace_context == None:
|
||||
hints_list = [None ] *len(x_list)
|
||||
else:
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||
c = c[0]
|
||||
else:
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
|
||||
del c
|
||||
|
||||
should_calc = True
|
||||
if self.enable_teacache:
|
||||
if is_uncond:
|
||||
@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
if joint_pass or not is_uncond:
|
||||
self.previous_residual_cond = None
|
||||
ori_hidden_states = x_list[0].clone()
|
||||
# arguments
|
||||
|
||||
kwargs = dict(
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=freqs,
|
||||
context_lens=context_lens)
|
||||
|
||||
if vace_context == None:
|
||||
hints_list = [None ] *len(x_list)
|
||||
else:
|
||||
hints_list = []
|
||||
for x, context in zip(x_list, context_list) :
|
||||
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
|
||||
del x, context
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
offload.shared_state["layer"] = block_idx
|
||||
|
||||
@ -143,6 +143,8 @@ class WanT2V:
|
||||
seq_len=32760,
|
||||
keep_last=True)
|
||||
|
||||
self.adapt_vace_model()
|
||||
|
||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(frames)
|
||||
@ -505,3 +507,14 @@ class WanT2V:
|
||||
dist.barrier()
|
||||
|
||||
return videos[0] if self.rank == 0 else None
|
||||
|
||||
def adapt_vace_model(self):
|
||||
model = self.model
|
||||
modules_dict= { k: m for k, m in model.named_modules()}
|
||||
for num in range(15):
|
||||
module = modules_dict[f"vace_blocks.{num}"]
|
||||
target = modules_dict[f"blocks.{2*num}"]
|
||||
setattr(target, "vace", module )
|
||||
delattr(model, "vace_blocks")
|
||||
|
||||
|
||||
23
wgp.py
23
wgp.py
@ -910,14 +910,6 @@ def get_queue_table(queue):
|
||||
if len(queue) == 1:
|
||||
return data
|
||||
|
||||
# def td(l, content, width =None):
|
||||
# if width !=None:
|
||||
# l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
|
||||
# else:
|
||||
# l.append("<TD>" + content + "</TD>")
|
||||
|
||||
# data.append("<STYLE> .TB, .TB th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
|
||||
|
||||
for i, item in enumerate(queue):
|
||||
if i==0:
|
||||
continue
|
||||
@ -937,22 +929,7 @@ def get_queue_table(queue):
|
||||
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
||||
if end_img_uri:
|
||||
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
||||
# if i % 2 == 1:
|
||||
# data.append("<TR>")
|
||||
# else:
|
||||
# data.append("<TR BGCOLOR=#F2F2F2>")
|
||||
|
||||
# td(data,str(item.get('repeats', "1")) )
|
||||
# td(data, prompt_cell, "100%")
|
||||
# td(data, num_steps, "100%")
|
||||
# td(data, start_img_md)
|
||||
# td(data, end_img_md)
|
||||
# td(data, "↑")
|
||||
# td(data, "↓")
|
||||
# td(data, "✖")
|
||||
# data.append("</TR>")
|
||||
# data.append("</TABLE>")
|
||||
# return ''.join(data)
|
||||
|
||||
data.append([item.get('repeats', "1"),
|
||||
prompt_cell,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user