diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py index 344f547..60178a9 100644 --- a/wan/modules/vace_model.py +++ b/wan/modules/vace_model.py @@ -32,15 +32,10 @@ class VaceWanAttentionBlock(WanAttentionBlock): def forward(self, c, x, **kwargs): if self.block_id == 0: c = self.before_proj(c) + x - all_c = [] - else: - all_c = list(torch.unbind(c)) - c = all_c.pop(-1) + c = super().forward(c, **kwargs) c_skip = self.after_proj(c) - all_c += [c_skip, c] - c = torch.stack(all_c) - return c + return c, c_skip class BaseWanAttentionBlock(WanAttentionBlock): @@ -134,9 +129,10 @@ class VaceWanModel(WanModel): new_kwargs = dict(x=x) new_kwargs.update(kwargs) + hints = [] for block in self.vace_blocks: - c = block(c, **new_kwargs) - hints = torch.unbind(c)[:-1] + c, c_skip = block(c, **new_kwargs) + hints.append(c_skip) return hints def forward(