diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index e4be2e0..73ea8b6 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -87,9 +87,10 @@ def usp_dit_forward_vace( c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + hints = [] for block in self.vace_blocks: - c = block(c, **new_kwargs) - hints = torch.unbind(c)[:-1] + c, c_skip = block(c, **new_kwargs) + hints.append(c_skip) return hints