Compare commits


1 Commit

Author: Ang Wang
SHA1: d68ac2e68d
Message: Merge 96c5ec7605 into 204f899b64
Date: 2025-05-08 20:00:14 +08:00


@@ -32,10 +32,15 @@ class VaceWanAttentionBlock(WanAttentionBlock):
     def forward(self, c, x, **kwargs):
         if self.block_id == 0:
             c = self.before_proj(c) + x
+            all_c = []
+        else:
+            all_c = list(torch.unbind(c))
+            c = all_c.pop(-1)
         c = super().forward(c, **kwargs)
         c_skip = self.after_proj(c)
-        return c, c_skip
+        all_c += [c_skip, c]
+        c = torch.stack(all_c)
+        return c


 class BaseWanAttentionBlock(WanAttentionBlock):
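
The new forward threads every block's skip output through a single stacked tensor instead of returning a (c, c_skip) tuple: block 0 starts the pack, each later block unbinds it, pops the running hidden state off the end, and re-stacks with its own skip appended. A minimal self-contained sketch of that packing protocol, with toy stand-ins for the real projections and attention body:

import torch

# Illustrative sketch only; the arithmetic stands in for the real
# before_proj / after_proj / WanAttentionBlock.forward computations.
def toy_vace_block(c, x, block_id):
    if block_id == 0:
        c = c + x                      # stands in for before_proj(c) + x
        all_c = []
    else:
        all_c = list(torch.unbind(c))  # unpack [skip_0, ..., skip_{i-1}, c]
        c = all_c.pop(-1)              # the running hidden state sits last
    c = c * 2.0                        # stands in for super().forward(c, **kwargs)
    c_skip = c + 1.0                   # stands in for after_proj(c)
    all_c += [c_skip, c]               # skips stay in block order; c returns to the end
    return torch.stack(all_c)          # single tensor out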
@@ -129,10 +134,9 @@ class VaceWanModel(WanModel):
         new_kwargs = dict(x=x)
         new_kwargs.update(kwargs)
-        hints = []
         for block in self.vace_blocks:
-            c, c_skip = block(c, **new_kwargs)
-            hints.append(c_skip)
+            c = block(c, **new_kwargs)
+        hints = torch.unbind(c)[:-1]
         return hints

     def forward(
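
On the model side the loop now hands one tensor from block to block and recovers the per-block skips only once, after the loop: unbind()[:-1] drops the trailing hidden state and yields the hints in block order. Continuing the sketch above with hypothetical (seq_len, dim) shapes:

x = torch.randn(4, 8)
c = torch.randn(4, 8)
for block_id in range(3):              # stands in for iterating self.vace_blocks
    c = toy_vace_block(c, x, block_id)
hints = torch.unbind(c)[:-1]           # (skip_0, skip_1, skip_2); hidden state dropped
assert len(hints) == 3 and hints[0].shape == x.shape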