mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-07 15:54:52 +00:00
Support training with multiple gpus
This commit is contained in:
parent d372c7eb6b
commit 1cfa46a3cd
@@ -63,12 +63,44 @@ def rope_apply(x, grid_sizes, freqs):
     return torch.stack(output).float()
 
 
+def usp_dit_forward_vace(
+    self,
+    x,
+    vace_context,
+    seq_len,
+    kwargs
+):
+    # embeddings
+    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    c = [u.flatten(2).transpose(1, 2) for u in c]
+    c = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                  dim=1) for u in c
+    ])
+
+    # arguments
+    new_kwargs = dict(x=x)
+    new_kwargs.update(kwargs)
+
+    # Context Parallel
+    c = torch.chunk(
+        c, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.vace_blocks:
+        c = block(c, **new_kwargs)
+    hints = torch.unbind(c)[:-1]
+    return hints
+
+
 def usp_dit_forward(
     self,
     x,
     t,
     context,
     seq_len,
+    vace_context=None,
+    vace_context_scale=1.0,
     clip_fea=None,
     y=None,
 ):
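The new usp_dit_forward_vace does two pieces of tensor plumbing worth calling out: each variable-length VACE context embedding is zero-padded along the sequence dimension to a common seq_len, and the padded batch is then sharded across sequence-parallel ranks with torch.chunk. A minimal, self-contained sketch of both steps in plain PyTorch, with the get_sequence_parallel_* calls replaced by illustrative constants:

import torch

seq_len = 12                  # target padded length
world_size = 4                # stands in for get_sequence_parallel_world_size()

# Variable-length embeddings, shape [1, L_i, C], as produced by
# vace_patch_embedding + flatten(2).transpose(1, 2).
c = [torch.randn(1, n, 8) for n in (5, 9, 12)]

# Zero-pad every sample to seq_len on dim=1, then concatenate on the batch
# dim, mirroring the torch.cat([u, u.new_zeros(...)], dim=1) pattern above.
c = torch.cat([
    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
    for u in c
])
print(c.shape)  # torch.Size([3, 12, 8])

# Context parallel: each rank keeps only its slice of the sequence dimension.
for rank in range(world_size):  # rank stands in for get_sequence_parallel_rank()
    shard = torch.chunk(c, world_size, dim=1)[rank]
    print(rank, shard.shape)    # every rank holds a [3, 3, 8] slice

Since the shard is taken by plain chunking, seq_len is expected to divide evenly by the sequence-parallel world size; torch.chunk would otherwise hand the last rank a shorter slice.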
@@ -77,14 +109,14 @@ def usp_dit_forward(
         t: [B].
         context: A list of text embeddings each with shape [L, C].
     """
-    if self.model_type == 'i2v':
+    if self.model_type != 'vace':
         assert clip_fea is not None and y is not None
     # params
     device = self.patch_embedding.weight.device
     if self.freqs.device != device:
         self.freqs = self.freqs.to(device)
 
-    if y is not None:
+    if self.model_type != 'vace' and y is not None:
         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
     # embeddings
@@ -114,7 +146,7 @@ def usp_dit_forward(
         for u in context
     ]))
 
-    if clip_fea is not None:
+    if self.model_type != 'vace' and clip_fea is not None:
         context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         context = torch.concat([context_clip, context], dim=1)
 
@@ -132,6 +164,11 @@ def usp_dit_forward(
         x, get_sequence_parallel_world_size(),
         dim=1)[get_sequence_parallel_rank()]
 
+    if self.model_type == 'vace':
+        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
+        kwargs['hints'] = hints
+        kwargs['context_scale'] = vace_context_scale
+
     for block in self.blocks:
         x = block(x, **kwargs)
 
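The VACE branch above computes the hints once, then threads them to every transformer block through kwargs. This diff does not show how the blocks consume hints and context_scale; a common control-injection pattern, assumed here purely for illustration, is an additive residual scaled by context_scale:

import torch

def toy_block(x, hint=None, context_scale=1.0):
    # Hypothetical block body: ordinary processing plus an optional scaled
    # residual from the VACE hint (an assumption; the real block internals
    # are not part of this diff).
    x = x * 2.0                      # stand-in for attention / FFN work
    if hint is not None:
        x = x + context_scale * hint
    return x

x = torch.randn(3, 12, 8)
hints = [torch.randn(3, 12, 8) for _ in range(2)]

# Mirrors kwargs['hints'] / kwargs['context_scale'] being set once before
# the main loop and reused by every block.
kwargs = dict(context_scale=0.8)
for hint in hints:
    x = toy_block(x, hint=hint, **kwargs)
print(x.shape)  # torch.Size([3, 12, 8])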
@@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
     Args:
         model_type (`str`, *optional*, defaults to 't2v'):
-            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
+            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
         patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
             3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
         text_len (`int`, *optional*, defaults to 512):
@@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         super().__init__()
 
-        assert model_type in ['t2v', 'i2v', 'flf2v']
+        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
        self.model_type = model_type
 
         self.patch_size = patch_size
@@ -71,7 +71,7 @@ class VaceWanModel(WanModel):
     def __init__(self,
                  vace_layers=None,
                  vace_in_dim=None,
-                 model_type='t2v',
+                 model_type='vace',
                  patch_size=(1, 2, 2),
                  text_len=512,
                  in_dim=16,
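With 'vace' added to the WanModel assert and made the default for VaceWanModel, the shared forward path can branch on self.model_type rather than assuming image conditioning. A toy sketch of that validation-and-dispatch shape (illustrative class names, not the real modules):

class ToyWanModel:
    VALID_TYPES = ('t2v', 'i2v', 'flf2v', 'vace')  # mirrors the extended assert

    def __init__(self, model_type='t2v'):
        assert model_type in self.VALID_TYPES
        self.model_type = model_type

    def forward(self, x, clip_fea=None, y=None):
        if self.model_type != 'vace':
            # non-VACE variants carry image conditioning, as in the updated guard
            assert clip_fea is not None and y is not None
        return x


class ToyVaceWanModel(ToyWanModel):
    def __init__(self, model_type='vace'):  # default changed from 't2v'
        super().__init__(model_type)


model = ToyVaceWanModel()
print(model.model_type)      # vace
print(model.forward([0.0]))  # VACE path skips the conditioning assert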