mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 06:15:17 +00:00

Optimized Vace RAM usage

This commit is contained in:
parent f5bb9d5972
commit 8d998d645f
@@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
             grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
+        hint = None
+        if self.block_id is not None and hints is not None:
+            kwargs = {
+                "seq_lens" : seq_lens,
+                "grid_sizes" : grid_sizes,
+                "freqs" : freqs,
+                "context" : context,
+                "context_lens" : context_lens,
+                "e" : e,
+            }
+            if self.block_id == 0:
+                hint = self.vace(hints, x, **kwargs)
+            else:
+                hint = self.vace(hints, None, **kwargs)
+
         e = (self.modulation + e).chunk(6, dim=1)
 
         # self-attention
@@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
         x.addcmul_(y, e[5])
 
-        if self.block_id is not None and hints != None:
+        if hint is not None:
             if context_scale == 1:
-                x.add_(hints[self.block_id])
+                x.add_(hint)
             else:
-                x.add_(hints[self.block_id], alpha =context_scale)
-        return x
+                x.add_(hint, alpha= context_scale)
+        return x
 
 
 class VaceWanAttentionBlock(WanAttentionBlock):
     def __init__(
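Taken together, these two hunks change WanAttentionBlock from looking up a precomputed skip tensor (`hints[self.block_id]`) to asking a grafted `self.vace` module for its hint on demand, then folding it in with an in-place `add_`. That is the point of the commit: only one hint tensor needs to be alive at a time instead of one per VACE block. A minimal runnable sketch of the pattern, with toy modules standing in for the repository's classes (names and shapes here are illustrative assumptions, not the project's API):

```python
import torch
from torch import nn

class ToyBlock(nn.Module):
    """Stand-in for WanAttentionBlock; names and shapes are illustrative only."""
    def __init__(self, dim, make_hint=None):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.make_hint = make_hint   # on-demand hint producer, or None

    def forward(self, x, context_scale=1.0):
        # Old pattern: x.add_(hints[self.block_id]) indexed a list of
        # precomputed skip tensors that all stayed resident together.
        # New pattern: produce the one hint this block needs, use it, drop it.
        hint = self.make_hint(x) if self.make_hint is not None else None
        x = x + self.attn(x)
        if hint is not None:
            if context_scale == 1:
                x.add_(hint)                       # in-place, no temporary
            else:
                x.add_(hint, alpha=context_scale)  # fused scale-and-add
        return x

dim = 16
block = ToyBlock(dim, make_hint=nn.Linear(dim, dim))
out = block(torch.randn(2, 8, dim), context_scale=0.8)
print(out.shape)  # torch.Size([2, 8, 16])
```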
@@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
         nn.init.zeros_(self.after_proj.weight)
         nn.init.zeros_(self.after_proj.bias)
 
-    def forward(self, c, x, **kwargs):
+    def forward(self, hints, x, **kwargs):
         # behold dbm magic !
+        c = hints[0]
+        hints[0] = None
         if self.block_id == 0:
             c = self.before_proj(c) + x
-            all_c = []
-        else:
-            all_c = c
-            c = all_c.pop(-1)
         c = super().forward(c, **kwargs)
         c_skip = self.after_proj(c)
-        all_c += [c_skip, c]
-        return all_c
+        hints[0] = c
+        return c_skip
+
+    # def forward(self, c, x, **kwargs):
+    #     # behold dbm magic !
+    #     if self.block_id == 0:
+    #         c = self.before_proj(c) + x
+    #         all_c = []
+    #     else:
+    #         all_c = c
+    #         c = all_c.pop(-1)
+    #     c = super().forward(c, **kwargs)
+    #     c_skip = self.after_proj(c)
+    #     all_c += [c_skip, c]
+    #     return all_c
 
 class Head(nn.Module):
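The rewritten `VaceWanAttentionBlock.forward` is the core of the RAM saving. The old version threaded a growing `all_c` list through the chain, accumulating two tensors (`c_skip`, `c`) per block; the new one keeps the running VACE activation in a one-slot list that is emptied while the block works and refilled before returning, so the caller only ever receives the single `c_skip` it is about to consume. A distilled sketch of the one-slot handoff (toy transforms, not the repository's API):

```python
import torch

def vace_step(state, transform):
    """Borrow the running tensor from the shared one-slot list, then return it."""
    c = state[0]
    state[0] = None      # empty the slot: no second reference keeps `c` alive
    c = transform(c)     # next running activation (super().forward in the diff)
    skip = c * 0.5       # toy stand-in for after_proj(c)
    state[0] = c         # refill the slot for the next grafted block
    return skip          # only the skip tensor escapes this step

state = [torch.randn(4)]          # shared one-slot buffer, like `hints`
for _ in range(3):
    skip = vace_step(state, lambda t: t + 1)
    # `skip` is consumed immediately and freed; nothing accumulates the way
    # `all_c += [c_skip, c]` used to.
print(state[0])  # the final running activation
```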
@@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
         print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         return best_threshold
 
-    def forward_vace(
-        self,
-        x,
-        vace_context,
-        seq_len,
-        context,
-        e,
-        kwargs
-    ):
-        # embeddings
-        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
-        c = [u.flatten(2).transpose(1, 2) for u in c]
-        if (len(c) == 1 and seq_len == c[0].size(1)):
-            c = c[0]
-        else:
-            c = torch.cat([
-                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                        dim=1) for u in c
-            ])
-
-        # arguments
-        new_kwargs = dict(x=x)
-        new_kwargs.update(kwargs)
-
-        for block in self.vace_blocks:
-            c = block(c, context= context, e= e, **new_kwargs)
-        hints = c[:-1]
-
-        return hints
+    # def forward_vace(
+    #     self,
+    #     x,
+    #     vace_context,
+    #     seq_len,
+    #     context,
+    #     e,
+    #     kwargs
+    # ):
+    #     # embeddings
+    #     c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    #     c = [u.flatten(2).transpose(1, 2) for u in c]
+    #     if (len(c) == 1 and seq_len == c[0].size(1)):
+    #         c = c[0]
+    #     else:
+    #         c = torch.cat([
+    #             torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+    #                     dim=1) for u in c
+    #         ])
+
+    #     # arguments
+    #     new_kwargs = dict(x=x)
+    #     new_kwargs.update(kwargs)
+
+    #     for block in self.vace_blocks:
+    #         c = block(c, context= context, e= e, **new_kwargs)
+    #     hints = c[:-1]
+
+    #     return hints
 
     def forward(
         self,
@@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
             x_list = [x]
             context_list = [context]
         del x
+
+            # arguments
+
+        kwargs = dict(
+            seq_lens=seq_lens,
+            grid_sizes=grid_sizes,
+            freqs=freqs,
+            context_lens=context_lens,
+            )
+
+        if vace_context == None:
+            hints_list = [None ] *len(x_list)
+        else:
+            # embeddings
+            c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+            c = [u.flatten(2).transpose(1, 2) for u in c]
+            if (len(c) == 1 and seq_len == c[0].size(1)):
+                c = c[0]
+            else:
+                c = torch.cat([
+                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                            dim=1) for u in c
+                ])
+
+            kwargs['context_scale'] = vace_context_scale
+            hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
+            del c
+
         should_calc = True
         if self.enable_teacache:
             if is_uncond:
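Two details of the inlined setup above are easy to miss: variable-length VACE context embeddings are zero-padded to `seq_len` before being stacked, and when several denoising streams run in one joint pass (e.g. conditional and unconditional), each stream gets its own one-slot hint buffer, cloning `c` only for the extra streams because every stream's VACE blocks mutate their buffer in place. A hedged sketch of both steps (toy shapes; the stream count is an assumption):

```python
import torch

seq_len = 10
# Toy embedded contexts of different lengths: [1, L_i, C]
cs = [torch.randn(1, 7, 4), torch.randn(1, 10, 4)]

# Zero-pad each context to seq_len along dim 1, then stack on the batch dim,
# mirroring torch.cat([u, u.new_zeros(...)], dim=1) in the diff.
c = torch.cat([
    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
    for u in cs
])
print(c.shape)  # torch.Size([2, 10, 4])

# One one-slot buffer per stream; only the extra streams pay for a clone,
# since the first stream can consume `c` itself.
n_streams = 2
hints_list = [[c] if i == 0 else [c.clone()] for i in range(n_streams)]
del c  # drop the loose reference so each buffer solely owns its tensor
```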
@@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
                 if joint_pass or not is_uncond:
                     self.previous_residual_cond = None
                 ori_hidden_states = x_list[0].clone()
-            # arguments
-
-            kwargs = dict(
-                seq_lens=seq_lens,
-                grid_sizes=grid_sizes,
-                freqs=freqs,
-                context_lens=context_lens)
-
-            if vace_context == None:
-                hints_list = [None ] *len(x_list)
-            else:
-                hints_list = []
-                for x, context in  zip(x_list,  context_list) :
-                    hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0,  kwargs= kwargs))
-                del x, context
-                kwargs['context_scale'] = vace_context_scale
-
 
             for block_idx, block in enumerate(self.blocks):
                 offload.shared_state["layer"] = block_idx
@@ -143,6 +143,8 @@ class WanT2V:
                                             seq_len=32760,
                                             keep_last=True)
 
+            self.adapt_vace_model()
+
     def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
         if ref_images is None:
             ref_images = [None] * len(frames)
@@ -505,3 +507,14 @@ class WanT2V:
             dist.barrier()
 
         return videos[0] if self.rank == 0 else None
+
+    def adapt_vace_model(self):
+        model = self.model
+        modules_dict= { k: m for k, m in model.named_modules()}
+        for num in range(15):
+            module = modules_dict[f"vace_blocks.{num}"]
+            target = modules_dict[f"blocks.{2*num}"]
+            setattr(target, "vace", module )
+        delattr(model, "vace_blocks")
+
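`adapt_vace_model` is what makes the on-demand hints possible: the 15 standalone VACE blocks are looked up by name, grafted one-to-one onto every second transformer block (`blocks.{2*num}`), and the original `vace_blocks` container is deleted, so each host block owns the module that produces its hint. A small runnable sketch of the same grafting pattern on toy modules (module names here are illustrative):

```python
from torch import nn

class ToyModel(nn.Module):
    def __init__(self, n_blocks=4, n_vace=2):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(n_blocks))
        self.vace_blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(n_vace))

model = ToyModel()
modules_dict = {k: m for k, m in model.named_modules()}
for num in range(2):  # the diff uses range(15) for its 15 VACE blocks
    module = modules_dict[f"vace_blocks.{num}"]
    target = modules_dict[f"blocks.{2*num}"]
    setattr(target, "vace", module)   # graft: the host block owns its hint producer
delattr(model, "vace_blocks")         # drop the standalone container

# blocks.0 and blocks.2 now carry a `vace` submodule; odd-indexed blocks do not.
print(hasattr(model.blocks[0], "vace"), hasattr(model.blocks[1], "vace"))  # True False
```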
wgp.py (23 lines changed)
@@ -910,14 +910,6 @@ def get_queue_table(queue):
     if len(queue) == 1:
         return data
 
-    # def td(l, content, width =None):
-    #     if width !=None:
-    #         l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
-    #     else:
-    #         l.append("<TD>" + content + "</TD>")
-
-    # data.append("<STYLE> .TB, .TB  th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
-
     for i, item in enumerate(queue):
         if i==0:
             continue
@@ -937,22 +929,7 @@ def get_queue_table(queue):
             start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
         if end_img_uri:
             end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
-    #     if i % 2 == 1:
-    #         data.append("<TR>")
-    #     else:
-    #         data.append("<TR BGCOLOR=#F2F2F2>")
 
-    #     td(data,str(item.get('repeats', "1")) )
-    #     td(data, prompt_cell, "100%")
-    #     td(data, num_steps, "100%")
-    #     td(data, start_img_md)
-    #     td(data, end_img_md)
-    #     td(data, "↑")
-    #     td(data, "↓")
-    #     td(data, "✖")
-    #     data.append("</TR>")
-    # data.append("</TABLE>")
-    # return ''.join(data)
 
         data.append([item.get('repeats', "1"),
                     prompt_cell,