vace block update

This commit is contained in:
hanzhn 2025-05-12 19:39:56 +08:00
parent 96c5ec7605
commit 5bfce89227

View File

@ -32,15 +32,10 @@ class VaceWanAttentionBlock(WanAttentionBlock):
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
return c, c_skip
class BaseWanAttentionBlock(WanAttentionBlock):
@ -134,9 +129,10 @@ class VaceWanModel(WanModel):
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
hints = []
for block in self.vace_blocks:
c = block(c, **new_kwargs)
hints = torch.unbind(c)[:-1]
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def forward(