Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-07-13 19:20:09 +00:00)

[FEAT] enable multi-device teacache
commit 3e209f8b45, parent 420963b7b5
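For context, TeaCache trades a little accuracy for speed by skipping the transformer blocks on diffusion steps where the timestep-modulated input has barely changed: it accumulates a polynomially rescaled relative L1 distance against the previous step's input and only runs the full block stack once that accumulator crosses a threshold. A minimal sketch of that decision, with illustrative names (prev_inp, accum, coefficients, teacache_thresh) rather than the repository's actual attributes:

import numpy as np
import torch

def teacache_should_recompute(modulated_inp: torch.Tensor,
                              prev_inp: torch.Tensor,
                              accum: float,
                              coefficients: list,
                              teacache_thresh: float):
    # Relative L1 change of the modulated timestep embedding vs. the previous step.
    rel_l1 = ((modulated_inp - prev_inp).abs().mean() / prev_inp.abs().mean()).cpu().item()
    # Rescale with a polynomial fitted per model (np.poly1d takes the highest order first).
    accum += np.poly1d(coefficients)(rel_l1)
    if accum < teacache_thresh:
        return False, accum  # cheap step: reuse the cached block residual
    return True, 0.0         # run all blocks and reset the accumulator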
@@ -1,6 +1,9 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+import numpy as np
+import logging
+
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
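The two new imports support the TeaCache path added below: numpy provides np.poly1d for turning the fitted rescale coefficients into a callable polynomial, and logging reports when a step is served from the cached residual. A quick illustration of np.poly1d (the coefficient values here are made up, not the ones shipped with the model):

import numpy as np

# Coefficients are ordered from the highest power down,
# so [2.0, 0.5, 0.1] represents 2*x**2 + 0.5*x + 0.1.
rescale_func = np.poly1d([2.0, 0.5, 0.1])
print(rescale_func(0.2))  # -> 0.28 (2*0.04 + 0.5*0.2 + 0.1)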
@@ -166,8 +169,66 @@ def usp_dit_forward(
     kwargs['hints'] = hints
     kwargs['context_scale'] = vace_context_scale
 
-    for block in self.blocks:
-        x = block(x, **kwargs)
+    if self.enable_teacache:
+        # print("enable teacache")
+        modulated_inp = e0 if self.use_ref_steps else e
+        device_id = torch.cuda.current_device()
+        # print("check args, {} {} {} {} {}".format(self.cnt, self.ret_steps, self.cutoff_steps, self.coefficients,
+        # device_id))
+        # TeaCache: decide whether this step can reuse the cached residual
+        if self.cnt % 2 == 0:  # even -> condition
+            self.is_even = True
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_even = True
+                self.accumulated_rel_l1_distance_even = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                    should_calc_even = False
+                else:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+            self.previous_e0_even = modulated_inp.clone()
+
+        else:  # odd -> uncondition
+            self.is_even = False
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_odd = True
+                self.accumulated_rel_l1_distance_odd = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                    should_calc_odd = False
+                else:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+            self.previous_e0_odd = modulated_inp.clone()
+
+    if self.enable_teacache:
+        if self.is_even:
+            if not should_calc_even:
+                logging.info("use residual estimation for this diffusion step")
+                x += self.previous_residual_even
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_even = x - ori_x
+        else:
+            if not should_calc_odd:
+                logging.info("use residual estimation for this diffusion step")
+                x += self.previous_residual_odd
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_odd = x - ori_x
+
+    else:
+        for block in self.blocks:
+            x = block(x, **kwargs)
 
     # head
     x = self.head(x, e)
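Classifier-free guidance runs one conditional and one unconditional forward per sampling step, so the code keys its cache on the parity of self.cnt: even calls track the conditional stream and odd calls the unconditional one, each with its own accumulated distance, previous input, and residual. For a single stream, the reuse pattern boils down to something like the following sketch (run_stream and the cache dict are illustrative, not the repository's API):

import torch

def run_stream(x, blocks, cache, should_calc, **kwargs):
    # Skip path: apply the residual remembered from the last full pass.
    if not should_calc and cache.get("residual") is not None:
        return x + cache["residual"]
    # Full path: run every block and refresh the cached residual.
    ori_x = x.clone()
    for block in blocks:
        x = block(x, **kwargs)
    cache["residual"] = x - ori_x
    return x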
@@ -175,6 +236,10 @@ def usp_dit_forward(
     # Context Parallel
     x = get_sp_group().all_gather(x, dim=1)
 
+    self.cnt += 1
+    if self.cnt >= self.num_steps:
+        self.cnt = 0
+
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
     return [u.float() for u in x]
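Note that self.cnt advances once per forward call, so a sampling step with classifier-free guidance moves it by two (conditional plus unconditional); it wraps back to zero once it reaches self.num_steps, presumably so that a later generation with the same module starts from a fresh TeaCache state.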