Support training with multiple GPUs

wangang.wa 2025-05-08 18:34:26 +08:00
parent d372c7eb6b
commit 1cfa46a3cd
3 changed files with 43 additions and 6 deletions


@@ -63,12 +63,44 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()
 
 
+def usp_dit_forward_vace(
+    self,
+    x,
+    vace_context,
+    seq_len,
+    kwargs
+):
+    # embeddings
+    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    c = [u.flatten(2).transpose(1, 2) for u in c]
+    c = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                  dim=1) for u in c
+    ])
+
+    # arguments
+    new_kwargs = dict(x=x)
+    new_kwargs.update(kwargs)
+
+    # Context Parallel
+    c = torch.chunk(
+        c, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.vace_blocks:
+        c = block(c, **new_kwargs)
+    hints = torch.unbind(c)[:-1]
+    return hints
+
+
 def usp_dit_forward(
     self,
     x,
     t,
     context,
     seq_len,
+    vace_context=None,
+    vace_context_scale=1.0,
     clip_fea=None,
     y=None,
 ):
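
The new usp_dit_forward_vace zero-pads every VACE context embedding to the shared seq_len and then chunks the sequence dimension across ranks, so each GPU runs the VACE blocks on its own slice only. A minimal, self-contained sketch of that pad-then-chunk pattern (sizes are made up; in the real code world_size and rank come from get_sequence_parallel_world_size() and get_sequence_parallel_rank()):

import torch

# Stand-ins for the sequence-parallel helpers; plain ints for illustration.
world_size, rank = 2, 0
seq_len, dim = 8, 4  # seq_len must be divisible by world_size for an even split

# Two "context" sequences of different lengths, like the entries of vace_context.
contexts = [torch.randn(1, 5, dim), torch.randn(1, 8, dim)]

# Zero-pad each sequence to seq_len along dim=1, then batch them along dim=0.
c = torch.cat([
    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
    for u in contexts
])  # [2, seq_len, dim]

# Each rank keeps only its slice of the sequence dimension.
local = torch.chunk(c, world_size, dim=1)[rank]
print(local.shape)  # torch.Size([2, 4, 4]): seq_len / world_size tokens per rank
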
@@ -77,14 +109,14 @@ def usp_dit_forward(
     t: [B].
     context: A list of text embeddings each with shape [L, C].
     """
-    if self.model_type == 'i2v':
+    if self.model_type != 'vace':
         assert clip_fea is not None and y is not None
     # params
     device = self.patch_embedding.weight.device
     if self.freqs.device != device:
         self.freqs = self.freqs.to(device)
 
-    if y is not None:
+    if self.model_type != 'vace' and y is not None:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
     # embeddings
@@ -114,7 +146,7 @@ def usp_dit_forward(
             for u in context
         ]))
 
-    if clip_fea is not None:
+    if self.model_type != 'vace' and clip_fea is not None:
         context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         context = torch.concat([context_clip, context], dim=1)
 
@@ -132,6 +164,11 @@ def usp_dit_forward(
         x, get_sequence_parallel_world_size(),
         dim=1)[get_sequence_parallel_rank()]
 
+    if self.model_type == 'vace':
+        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
+        kwargs['hints'] = hints
+        kwargs['context_scale'] = vace_context_scale
+
     for block in self.blocks:
         x = block(x, **kwargs)
 
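
When model_type is 'vace', the forward runs the VACE blocks first via self.forward_vace, then hands the resulting hints and vace_context_scale to every main block through kwargs. For this to work under USP, usp_dit_forward_vace has to be bound to the model under the name forward_vace, mirroring how Wan installs usp_dit_forward with types.MethodType. The call site is not part of this diff, so the sketch below is an assumption about that wiring:

import types

def enable_usp(model):
    # Swap in the sequence-parallel forward, following the types.MethodType
    # pattern Wan uses when enabling USP.
    model.forward = types.MethodType(usp_dit_forward, model)
    if model.model_type == 'vace':
        # usp_dit_forward calls self.forward_vace(...), so the VACE variant
        # also needs this binding (assumed wiring; not shown in this diff).
        model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
    return model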


@@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         Args:
             model_type (`str`, *optional*, defaults to 't2v'):
-                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
+                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
             patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
             text_len (`int`, *optional*, defaults to 512):
@@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         super().__init__()
 
-        assert model_type in ['t2v', 'i2v', 'flf2v']
+        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
         self.model_type = model_type
 
         self.patch_size = patch_size
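
With 'vace' accepted here, a VACE model passes the constructor assert, and the guards added to usp_dit_forward above exempt it from the clip_fea/y inputs that 'i2v' and 'flf2v' require. A small construction sketch, assuming the remaining WanModel arguments keep the defaults documented in its docstring:

# Before this commit, model_type='vace' failed the assert in WanModel.__init__.
model = WanModel(model_type='vace', patch_size=(1, 2, 2), text_len=512)
assert model.model_type == 'vace'
# Unlike 'i2v'/'flf2v', a 'vace' model is not required to pass clip_fea or y.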


@@ -71,7 +71,7 @@ class VaceWanModel(WanModel):
     def __init__(self,
                  vace_layers=None,
                  vace_in_dim=None,
-                 model_type='t2v',
+                 model_type='vace',
                  patch_size=(1, 2, 2),
                  text_len=512,
                  in_dim=16,
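
VaceWanModel now defaults to the new 'vace' model_type, so instances pick up the VACE branches in usp_dit_forward without the caller setting the variant explicitly. A usage sketch; the vace_layers and vace_in_dim values are illustrative placeholders, not values taken from this diff:

model = VaceWanModel(
    vace_layers=(0, 5, 10, 15),  # layers that receive VACE blocks (placeholder)
    vace_in_dim=96,              # channels of the VACE conditioning input (placeholder)
    # model_type now defaults to 'vace', satisfying the WanModel assert above
)
assert model.model_type == 'vace'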