mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-14 03:30:10 +00:00
add teacache switch
This commit is contained in:
parent
3e209f8b45
commit
8bbcee625e
@ -169,7 +169,11 @@ def usp_dit_forward(
|
|||||||
kwargs['hints'] = hints
|
kwargs['hints'] = hints
|
||||||
kwargs['context_scale'] = vace_context_scale
|
kwargs['context_scale'] = vace_context_scale
|
||||||
|
|
||||||
if self.enable_teacache:
|
teacache_switch = True
|
||||||
|
if not hasattr(self, "enable_teacache"):
|
||||||
|
teacache_switch = False
|
||||||
|
|
||||||
|
if teacache_switch and self.enable_teacache:
|
||||||
# print("enable teacache")
|
# print("enable teacache")
|
||||||
modulated_inp = e0 if self.use_ref_steps else e
|
modulated_inp = e0 if self.use_ref_steps else e
|
||||||
device_id = torch.cuda.current_device()
|
device_id = torch.cuda.current_device()
|
||||||
@ -206,7 +210,7 @@ def usp_dit_forward(
|
|||||||
self.accumulated_rel_l1_distance_odd = 0
|
self.accumulated_rel_l1_distance_odd = 0
|
||||||
self.previous_e0_odd = modulated_inp.clone()
|
self.previous_e0_odd = modulated_inp.clone()
|
||||||
|
|
||||||
if self.enable_teacache:
|
if teacache_switch and self.enable_teacache:
|
||||||
if self.is_even:
|
if self.is_even:
|
||||||
if not should_calc_even:
|
if not should_calc_even:
|
||||||
logging.info("use residual estimation for this difusion step")
|
logging.info("use residual estimation for this difusion step")
|
||||||
@ -236,6 +240,7 @@ def usp_dit_forward(
|
|||||||
# Context Parallel
|
# Context Parallel
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
|
||||||
|
if teacache_switch:
|
||||||
self.cnt += 1
|
self.cnt += 1
|
||||||
if self.cnt >= self.num_steps:
|
if self.cnt >= self.num_steps:
|
||||||
self.cnt = 0
|
self.cnt = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user