mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Support training with multiple GPUs
This commit is contained in:
		
							parent
							
								
									d372c7eb6b
								
							
						
					
					
						commit
						1cfa46a3cd
					
				@ -63,12 +63,44 @@ def rope_apply(x, grid_sizes, freqs):
 | 
			
		||||
    return torch.stack(output).float()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def usp_dit_forward_vace(
    self,
    x,
    vace_context,
    seq_len,
    kwargs
):
    """Sequence-parallel (USP) forward pass through the VACE control blocks.

    Args:
        self: the model instance; provides ``vace_patch_embedding`` and
            ``vace_blocks``.
        x: main-branch hidden states, forwarded to each VACE block via kwargs.
        vace_context: iterable of context tensors, one per sample; each is
            patch-embedded independently.
        seq_len: target sequence length every embedded context is padded to.
            NOTE(review): assumes each embedded context is no longer than
            ``seq_len`` — confirm with callers.
        kwargs: dict of extra keyword arguments passed through to each block
            (a plain dict parameter, not ``**kwargs``).

    Returns:
        Tuple of hint tensors, one per VACE layer (the last unbound slice is
        dropped).
    """
    # Patch-embed every context clip, flatten the spatial grid into a token
    # sequence, and right-pad each one with zeros up to seq_len.
    padded = []
    for clip in vace_context:
        emb = self.vace_patch_embedding(clip.unsqueeze(0))
        emb = emb.flatten(2).transpose(1, 2)
        fill = emb.new_zeros(1, seq_len - emb.size(1), emb.size(2))
        padded.append(torch.cat([emb, fill], dim=1))
    c = torch.cat(padded)

    # Block arguments: the main-branch states plus the caller-supplied extras.
    new_kwargs = {'x': x, **kwargs}

    # Context parallel: keep only this rank's slice of the token dimension.
    num_shards = get_sequence_parallel_world_size()
    shard_index = get_sequence_parallel_rank()
    c = torch.chunk(c, num_shards, dim=1)[shard_index]

    # Run the VACE stack; each block consumes and re-emits the control states.
    for block in self.vace_blocks:
        c = block(c, **new_kwargs)

    # Unbind along dim 0 and drop the trailing slice to obtain the hints.
    hints = torch.unbind(c)[:-1]
    return hints
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def usp_dit_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    x,
 | 
			
		||||
    t,
 | 
			
		||||
    context,
 | 
			
		||||
    seq_len,
 | 
			
		||||
    vace_context=None,
 | 
			
		||||
    vace_context_scale=1.0,
 | 
			
		||||
    clip_fea=None,
 | 
			
		||||
    y=None,
 | 
			
		||||
):
 | 
			
		||||
@ -77,14 +109,14 @@ def usp_dit_forward(
 | 
			
		||||
    t:              [B].
 | 
			
		||||
    context:        A list of text embeddings each with shape [L, C].
 | 
			
		||||
    """
 | 
			
		||||
    if self.model_type == 'i2v':
 | 
			
		||||
    if self.model_type != 'vace':
 | 
			
		||||
        assert clip_fea is not None and y is not None
 | 
			
		||||
    # params
 | 
			
		||||
    device = self.patch_embedding.weight.device
 | 
			
		||||
    if self.freqs.device != device:
 | 
			
		||||
        self.freqs = self.freqs.to(device)
 | 
			
		||||
 | 
			
		||||
    if y is not None:
 | 
			
		||||
    if self.model_type != 'vace' and y is not None:
 | 
			
		||||
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 | 
			
		||||
 | 
			
		||||
    # embeddings
 | 
			
		||||
@ -114,7 +146,7 @@ def usp_dit_forward(
 | 
			
		||||
            for u in context
 | 
			
		||||
        ]))
 | 
			
		||||
 | 
			
		||||
    if clip_fea is not None:
 | 
			
		||||
    if self.model_type != 'vace' and clip_fea is not None:
 | 
			
		||||
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
 | 
			
		||||
        context = torch.concat([context_clip, context], dim=1)
 | 
			
		||||
 | 
			
		||||
@ -132,6 +164,11 @@ def usp_dit_forward(
 | 
			
		||||
        x, get_sequence_parallel_world_size(),
 | 
			
		||||
        dim=1)[get_sequence_parallel_rank()]
 | 
			
		||||
 | 
			
		||||
    if self.model_type == 'vace':
 | 
			
		||||
        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
 | 
			
		||||
        kwargs['hints'] = hints
 | 
			
		||||
        kwargs['context_scale'] = vace_context_scale
 | 
			
		||||
 | 
			
		||||
    for block in self.blocks:
 | 
			
		||||
        x = block(x, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_type (`str`, *optional*, defaults to 't2v'):
 | 
			
		||||
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
 | 
			
		||||
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
 | 
			
		||||
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
 | 
			
		||||
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
 | 
			
		||||
            text_len (`int`, *optional*, defaults to 512):
 | 
			
		||||
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        assert model_type in ['t2v', 'i2v', 'flf2v']
 | 
			
		||||
        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
 | 
			
		||||
        self.model_type = model_type
 | 
			
		||||
 | 
			
		||||
        self.patch_size = patch_size
 | 
			
		||||
 | 
			
		||||
@ -71,7 +71,7 @@ class VaceWanModel(WanModel):
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 vace_layers=None,
 | 
			
		||||
                 vace_in_dim=None,
 | 
			
		||||
                 model_type='t2v',
 | 
			
		||||
                 model_type='vace',
 | 
			
		||||
                 patch_size=(1, 2, 2),
 | 
			
		||||
                 text_len=512,
 | 
			
		||||
                 in_dim=16,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user