From 1cfa46a3cdd368038e56fb2fe36221eaa7f2f9d5 Mon Sep 17 00:00:00 2001
From: "wangang.wa"
Date: Thu, 8 May 2025 18:34:26 +0800
Subject: [PATCH] Support training with multiple gpus

---
 wan/distributed/xdit_context_parallel.py | 43 ++++++++++++++++++++++--
 wan/modules/model.py                     |  4 +--
 wan/modules/vace_model.py                |  2 +-
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index 01936ce..e4be2e0 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -63,12 +63,44 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()
 
 
+def usp_dit_forward_vace(
+    self,
+    x,
+    vace_context,
+    seq_len,
+    kwargs
+):
+    # embeddings
+    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    c = [u.flatten(2).transpose(1, 2) for u in c]
+    c = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                  dim=1) for u in c
+    ])
+
+    # arguments
+    new_kwargs = dict(x=x)
+    new_kwargs.update(kwargs)
+
+    # Context Parallel
+    c = torch.chunk(
+        c, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.vace_blocks:
+        c = block(c, **new_kwargs)
+    hints = torch.unbind(c)[:-1]
+    return hints
+
+
 def usp_dit_forward(
     self,
     x,
     t,
     context,
     seq_len,
+    vace_context=None,
+    vace_context_scale=1.0,
     clip_fea=None,
     y=None,
 ):
@@ -77,14 +109,14 @@ def usp_dit_forward(
     t:              [B].
     context:        A list of text embeddings each with shape [L, C].
     """
-    if self.model_type == 'i2v':
+    if self.model_type != 'vace':
         assert clip_fea is not None and y is not None
     # params
     device = self.patch_embedding.weight.device
     if self.freqs.device != device:
         self.freqs = self.freqs.to(device)
 
-    if y is not None:
+    if self.model_type != 'vace' and y is not None:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
     # embeddings
@@ -114,7 +146,7 @@ def usp_dit_forward(
             for u in context
         ]))
 
-    if clip_fea is not None:
+    if self.model_type != 'vace' and clip_fea is not None:
         context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         context = torch.concat([context_clip, context], dim=1)
 
@@ -132,6 +164,11 @@ def usp_dit_forward(
         x, get_sequence_parallel_world_size(),
         dim=1)[get_sequence_parallel_rank()]
 
+    if self.model_type == 'vace':
+        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
+        kwargs['hints'] = hints
+        kwargs['context_scale'] = vace_context_scale
+
     for block in self.blocks:
         x = block(x, **kwargs)
 
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 7c6bddb..b94474a 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         Args:
             model_type (`str`, *optional*, defaults to 't2v'):
-                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
+                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
             patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
             text_len (`int`, *optional*, defaults to 512):
@@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         super().__init__()
 
-        assert model_type in ['t2v', 'i2v', 'flf2v']
+        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
         self.model_type = model_type
 
         self.patch_size = patch_size
diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py
index 85425d0..344f547 100644
--- a/wan/modules/vace_model.py
+++ b/wan/modules/vace_model.py
@@ -71,7 +71,7 @@ class VaceWanModel(WanModel):
     def __init__(self,
                  vace_layers=None,
                  vace_in_dim=None,
-                 model_type='t2v',
+                 model_type='vace',
                  patch_size=(1, 2, 2),
                  text_len=512,
                  in_dim=16,
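--
Note (not part of the patch): usp_dit_forward reaches the new VACE path
through self.forward_vace(...), so usp_dit_forward_vace is expected to be
bound onto the model instance, mirroring the types.MethodType monkey-patching
the repo already uses to install usp_dit_forward. A minimal wiring sketch
under that assumption; apply_usp_to_vace_model is a hypothetical helper name:

    import types

    from wan.distributed.xdit_context_parallel import (usp_dit_forward,
                                                       usp_dit_forward_vace)


    def apply_usp_to_vace_model(model):
        # Hypothetical helper: install the sequence-parallel forwards on a
        # VaceWanModel instance so that usp_dit_forward's 'vace' branch can
        # call usp_dit_forward_vace through self.forward_vace.
        model.forward = types.MethodType(usp_dit_forward, model)
        model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
        return model

With that binding in place, each rank pads every VACE context tensor to
seq_len, chunks it along the sequence dimension, and runs the vace_blocks on
its own slice only; the resulting hints are then injected into the main
blocks through kwargs['hints'] and kwargs['context_scale'], matching how x
itself is chunked earlier in usp_dit_forward.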