mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Optimized Vace RAM usage
This commit is contained in:
parent
f5bb9d5972
commit
8d998d645f
@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
|
|||||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||||
"""
|
"""
|
||||||
|
hint = None
|
||||||
|
if self.block_id is not None and hints is not None:
|
||||||
|
kwargs = {
|
||||||
|
"seq_lens" : seq_lens,
|
||||||
|
"grid_sizes" : grid_sizes,
|
||||||
|
"freqs" :freqs,
|
||||||
|
"context" : context,
|
||||||
|
"context_lens" : context_lens,
|
||||||
|
"e" : e,
|
||||||
|
}
|
||||||
|
if self.block_id == 0:
|
||||||
|
hint = self.vace(hints, x, **kwargs)
|
||||||
|
else:
|
||||||
|
hint = self.vace(hints, None, **kwargs)
|
||||||
|
|
||||||
e = (self.modulation + e).chunk(6, dim=1)
|
e = (self.modulation + e).chunk(6, dim=1)
|
||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
@ -486,13 +501,16 @@ class WanAttentionBlock(nn.Module):
|
|||||||
x.addcmul_(y, e[5])
|
x.addcmul_(y, e[5])
|
||||||
|
|
||||||
|
|
||||||
if self.block_id is not None and hints != None:
|
|
||||||
|
if hint is not None:
|
||||||
if context_scale == 1:
|
if context_scale == 1:
|
||||||
x.add_(hints[self.block_id])
|
x.add_(hint)
|
||||||
else:
|
else:
|
||||||
x.add_(hints[self.block_id], alpha =context_scale)
|
x.add_(hint, alpha= context_scale)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
|||||||
nn.init.zeros_(self.after_proj.weight)
|
nn.init.zeros_(self.after_proj.weight)
|
||||||
nn.init.zeros_(self.after_proj.bias)
|
nn.init.zeros_(self.after_proj.bias)
|
||||||
|
|
||||||
def forward(self, c, x, **kwargs):
|
def forward(self, hints, x, **kwargs):
|
||||||
# behold dbm magic !
|
# behold dbm magic !
|
||||||
|
c = hints[0]
|
||||||
|
hints[0] = None
|
||||||
if self.block_id == 0:
|
if self.block_id == 0:
|
||||||
c = self.before_proj(c) + x
|
c = self.before_proj(c) + x
|
||||||
all_c = []
|
|
||||||
else:
|
|
||||||
all_c = c
|
|
||||||
c = all_c.pop(-1)
|
|
||||||
c = super().forward(c, **kwargs)
|
c = super().forward(c, **kwargs)
|
||||||
c_skip = self.after_proj(c)
|
c_skip = self.after_proj(c)
|
||||||
all_c += [c_skip, c]
|
hints[0] = c
|
||||||
return all_c
|
return c_skip
|
||||||
|
|
||||||
|
# def forward(self, c, x, **kwargs):
|
||||||
|
# # behold dbm magic !
|
||||||
|
# if self.block_id == 0:
|
||||||
|
# c = self.before_proj(c) + x
|
||||||
|
# all_c = []
|
||||||
|
# else:
|
||||||
|
# all_c = c
|
||||||
|
# c = all_c.pop(-1)
|
||||||
|
# c = super().forward(c, **kwargs)
|
||||||
|
# c_skip = self.after_proj(c)
|
||||||
|
# all_c += [c_skip, c]
|
||||||
|
# return all_c
|
||||||
|
|
||||||
class Head(nn.Module):
|
class Head(nn.Module):
|
||||||
|
|
||||||
@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
||||||
return best_threshold
|
return best_threshold
|
||||||
|
|
||||||
def forward_vace(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
vace_context,
|
|
||||||
seq_len,
|
|
||||||
context,
|
|
||||||
e,
|
|
||||||
kwargs
|
|
||||||
):
|
|
||||||
# embeddings
|
|
||||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
|
||||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
|
||||||
if (len(c) == 1 and seq_len == c[0].size(1)):
|
|
||||||
c = c[0]
|
|
||||||
else:
|
|
||||||
c = torch.cat([
|
|
||||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
|
||||||
dim=1) for u in c
|
|
||||||
])
|
|
||||||
|
|
||||||
# arguments
|
|
||||||
new_kwargs = dict(x=x)
|
|
||||||
new_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
for block in self.vace_blocks:
|
# def forward_vace(
|
||||||
c = block(c, context= context, e= e, **new_kwargs)
|
# self,
|
||||||
hints = c[:-1]
|
# x,
|
||||||
|
# vace_context,
|
||||||
|
# seq_len,
|
||||||
|
# context,
|
||||||
|
# e,
|
||||||
|
# kwargs
|
||||||
|
# ):
|
||||||
|
# # embeddings
|
||||||
|
# c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
# c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
# if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||||
|
# c = c[0]
|
||||||
|
# else:
|
||||||
|
# c = torch.cat([
|
||||||
|
# torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||||
|
# dim=1) for u in c
|
||||||
|
# ])
|
||||||
|
|
||||||
return hints
|
# # arguments
|
||||||
|
# new_kwargs = dict(x=x)
|
||||||
|
# new_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
# for block in self.vace_blocks:
|
||||||
|
# c = block(c, context= context, e= e, **new_kwargs)
|
||||||
|
# hints = c[:-1]
|
||||||
|
|
||||||
|
# return hints
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
x_list = [x]
|
x_list = [x]
|
||||||
context_list = [context]
|
context_list = [context]
|
||||||
del x
|
del x
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
grid_sizes=grid_sizes,
|
||||||
|
freqs=freqs,
|
||||||
|
context_lens=context_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
if vace_context == None:
|
||||||
|
hints_list = [None ] *len(x_list)
|
||||||
|
else:
|
||||||
|
# embeddings
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
if (len(c) == 1 and seq_len == c[0].size(1)):
|
||||||
|
c = c[0]
|
||||||
|
else:
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
kwargs['context_scale'] = vace_context_scale
|
||||||
|
hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
|
||||||
|
del c
|
||||||
|
|
||||||
should_calc = True
|
should_calc = True
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
if is_uncond:
|
if is_uncond:
|
||||||
@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
if joint_pass or not is_uncond:
|
if joint_pass or not is_uncond:
|
||||||
self.previous_residual_cond = None
|
self.previous_residual_cond = None
|
||||||
ori_hidden_states = x_list[0].clone()
|
ori_hidden_states = x_list[0].clone()
|
||||||
# arguments
|
|
||||||
|
|
||||||
kwargs = dict(
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
grid_sizes=grid_sizes,
|
|
||||||
freqs=freqs,
|
|
||||||
context_lens=context_lens)
|
|
||||||
|
|
||||||
if vace_context == None:
|
|
||||||
hints_list = [None ] *len(x_list)
|
|
||||||
else:
|
|
||||||
hints_list = []
|
|
||||||
for x, context in zip(x_list, context_list) :
|
|
||||||
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
|
|
||||||
del x, context
|
|
||||||
kwargs['context_scale'] = vace_context_scale
|
|
||||||
|
|
||||||
|
|
||||||
for block_idx, block in enumerate(self.blocks):
|
for block_idx, block in enumerate(self.blocks):
|
||||||
offload.shared_state["layer"] = block_idx
|
offload.shared_state["layer"] = block_idx
|
||||||
|
|||||||
@ -143,6 +143,8 @@ class WanT2V:
|
|||||||
seq_len=32760,
|
seq_len=32760,
|
||||||
keep_last=True)
|
keep_last=True)
|
||||||
|
|
||||||
|
self.adapt_vace_model()
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
||||||
if ref_images is None:
|
if ref_images is None:
|
||||||
ref_images = [None] * len(frames)
|
ref_images = [None] * len(frames)
|
||||||
@ -505,3 +507,14 @@ class WanT2V:
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
return videos[0] if self.rank == 0 else None
|
return videos[0] if self.rank == 0 else None
|
||||||
|
|
||||||
|
def adapt_vace_model(self):
|
||||||
|
model = self.model
|
||||||
|
modules_dict= { k: m for k, m in model.named_modules()}
|
||||||
|
for num in range(15):
|
||||||
|
module = modules_dict[f"vace_blocks.{num}"]
|
||||||
|
target = modules_dict[f"blocks.{2*num}"]
|
||||||
|
setattr(target, "vace", module )
|
||||||
|
delattr(model, "vace_blocks")
|
||||||
|
|
||||||
|
|
||||||
23
wgp.py
23
wgp.py
@ -910,14 +910,6 @@ def get_queue_table(queue):
|
|||||||
if len(queue) == 1:
|
if len(queue) == 1:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# def td(l, content, width =None):
|
|
||||||
# if width !=None:
|
|
||||||
# l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
|
|
||||||
# else:
|
|
||||||
# l.append("<TD>" + content + "</TD>")
|
|
||||||
|
|
||||||
# data.append("<STYLE> .TB, .TB th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
|
|
||||||
|
|
||||||
for i, item in enumerate(queue):
|
for i, item in enumerate(queue):
|
||||||
if i==0:
|
if i==0:
|
||||||
continue
|
continue
|
||||||
@ -937,22 +929,7 @@ def get_queue_table(queue):
|
|||||||
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
||||||
if end_img_uri:
|
if end_img_uri:
|
||||||
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
||||||
# if i % 2 == 1:
|
|
||||||
# data.append("<TR>")
|
|
||||||
# else:
|
|
||||||
# data.append("<TR BGCOLOR=#F2F2F2>")
|
|
||||||
|
|
||||||
# td(data,str(item.get('repeats', "1")) )
|
|
||||||
# td(data, prompt_cell, "100%")
|
|
||||||
# td(data, num_steps, "100%")
|
|
||||||
# td(data, start_img_md)
|
|
||||||
# td(data, end_img_md)
|
|
||||||
# td(data, "↑")
|
|
||||||
# td(data, "↓")
|
|
||||||
# td(data, "✖")
|
|
||||||
# data.append("</TR>")
|
|
||||||
# data.append("</TABLE>")
|
|
||||||
# return ''.join(data)
|
|
||||||
|
|
||||||
data.append([item.get('repeats', "1"),
|
data.append([item.get('repeats', "1"),
|
||||||
prompt_cell,
|
prompt_cell,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user