mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00
vace block update
This commit is contained in:
parent
96c5ec7605
commit
5bfce89227
@ -32,15 +32,10 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
|||||||
def forward(self, c, x, **kwargs):
|
def forward(self, c, x, **kwargs):
|
||||||
if self.block_id == 0:
|
if self.block_id == 0:
|
||||||
c = self.before_proj(c) + x
|
c = self.before_proj(c) + x
|
||||||
all_c = []
|
|
||||||
else:
|
|
||||||
all_c = list(torch.unbind(c))
|
|
||||||
c = all_c.pop(-1)
|
|
||||||
c = super().forward(c, **kwargs)
|
c = super().forward(c, **kwargs)
|
||||||
c_skip = self.after_proj(c)
|
c_skip = self.after_proj(c)
|
||||||
all_c += [c_skip, c]
|
return c, c_skip
|
||||||
c = torch.stack(all_c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
|
|
||||||
class BaseWanAttentionBlock(WanAttentionBlock):
|
class BaseWanAttentionBlock(WanAttentionBlock):
|
||||||
@ -134,9 +129,10 @@ class VaceWanModel(WanModel):
|
|||||||
new_kwargs = dict(x=x)
|
new_kwargs = dict(x=x)
|
||||||
new_kwargs.update(kwargs)
|
new_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
hints = []
|
||||||
for block in self.vace_blocks:
|
for block in self.vace_blocks:
|
||||||
c = block(c, **new_kwargs)
|
c, c_skip = block(c, **new_kwargs)
|
||||||
hints = torch.unbind(c)[:-1]
|
hints.append(c_skip)
|
||||||
return hints
|
return hints
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
Reference in New Issue
Block a user