mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
Support training with multiple gpus
This commit is contained in:
parent
d372c7eb6b
commit
1cfa46a3cd
@ -63,12 +63,44 @@ def rope_apply(x, grid_sizes, freqs):
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
def usp_dit_forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
# Context Parallel
|
||||
c = torch.chunk(
|
||||
c, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.vace_blocks:
|
||||
c = block(c, **new_kwargs)
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
|
||||
def usp_dit_forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context=None,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
):
|
||||
@ -77,14 +109,14 @@ def usp_dit_forward(
|
||||
t: [B].
|
||||
context: A list of text embeddings each with shape [L, C].
|
||||
"""
|
||||
if self.model_type == 'i2v':
|
||||
if self.model_type != 'vace':
|
||||
assert clip_fea is not None and y is not None
|
||||
# params
|
||||
device = self.patch_embedding.weight.device
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
if y is not None:
|
||||
if self.model_type != 'vace' and y is not None:
|
||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
@ -114,7 +146,7 @@ def usp_dit_forward(
|
||||
for u in context
|
||||
]))
|
||||
|
||||
if clip_fea is not None:
|
||||
if self.model_type != 'vace' and clip_fea is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
@ -132,6 +164,11 @@ def usp_dit_forward(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
if self.model_type == 'vace':
|
||||
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||
kwargs['hints'] = hints
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
|
@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v']
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
@ -71,7 +71,7 @@ class VaceWanModel(WanModel):
|
||||
def __init__(self,
|
||||
vace_layers=None,
|
||||
vace_in_dim=None,
|
||||
model_type='t2v',
|
||||
model_type='vace',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
|
Loading…
Reference in New Issue
Block a user