mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	vace block update
This commit is contained in:
		
							parent
							
								
									96c5ec7605
								
							
						
					
					
						commit
						5bfce89227
					
				@ -32,15 +32,10 @@ class VaceWanAttentionBlock(WanAttentionBlock):
 | 
				
			|||||||
    def forward(self, c, x, **kwargs):
 | 
					    def forward(self, c, x, **kwargs):
 | 
				
			||||||
        if self.block_id == 0:
 | 
					        if self.block_id == 0:
 | 
				
			||||||
            c = self.before_proj(c) + x
 | 
					            c = self.before_proj(c) + x
 | 
				
			||||||
            all_c = []
 | 
					
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            all_c = list(torch.unbind(c))
 | 
					 | 
				
			||||||
            c = all_c.pop(-1)
 | 
					 | 
				
			||||||
        c = super().forward(c, **kwargs)
 | 
					        c = super().forward(c, **kwargs)
 | 
				
			||||||
        c_skip = self.after_proj(c)
 | 
					        c_skip = self.after_proj(c)
 | 
				
			||||||
        all_c += [c_skip, c]
 | 
					        return c, c_skip
 | 
				
			||||||
        c = torch.stack(all_c)
 | 
					 | 
				
			||||||
        return c
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseWanAttentionBlock(WanAttentionBlock):
 | 
					class BaseWanAttentionBlock(WanAttentionBlock):
 | 
				
			||||||
@ -134,9 +129,10 @@ class VaceWanModel(WanModel):
 | 
				
			|||||||
        new_kwargs = dict(x=x)
 | 
					        new_kwargs = dict(x=x)
 | 
				
			||||||
        new_kwargs.update(kwargs)
 | 
					        new_kwargs.update(kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hints = []
 | 
				
			||||||
        for block in self.vace_blocks:
 | 
					        for block in self.vace_blocks:
 | 
				
			||||||
            c = block(c, **new_kwargs)
 | 
					            c, c_skip = block(c, **new_kwargs)
 | 
				
			||||||
        hints = torch.unbind(c)[:-1]
 | 
					            hints.append(c_skip)
 | 
				
			||||||
        return hints
 | 
					        return hints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user