add teacache switch

This commit is contained in:
kousakawang 2025-05-26 15:25:46 +08:00
parent 3e209f8b45
commit 8bbcee625e

View File

@ -169,7 +169,11 @@ def usp_dit_forward(
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
if self.enable_teacache:
teacache_switch = True
if not hasattr(self, "enable_teacache"):
teacache_switch = False
if teacache_switch and self.enable_teacache:
# print("enable teacache")
modulated_inp = e0 if self.use_ref_steps else e
device_id = torch.cuda.current_device()
@ -206,7 +210,7 @@ def usp_dit_forward(
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
if self.enable_teacache:
if teacache_switch and self.enable_teacache:
if self.is_even:
if not should_calc_even:
logging.info("use residual estimation for this difusion step")
@ -235,10 +239,11 @@ def usp_dit_forward(
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
if teacache_switch:
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
# unpatchify
x = self.unpatchify(x, grid_sizes)