diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index f07661d..cf13c9e 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -169,7 +169,11 @@ def usp_dit_forward(
         kwargs['hints'] = hints
         kwargs['context_scale'] = vace_context_scale
 
-    if self.enable_teacache:
+    teacache_switch = True
+    if not hasattr(self, "enable_teacache"):
+        teacache_switch = False
+
+    if teacache_switch and self.enable_teacache:
         # print("enable teacache")
         modulated_inp = e0 if self.use_ref_steps else e
         device_id = torch.cuda.current_device()
@@ -206,7 +210,7 @@ def usp_dit_forward(
                 self.accumulated_rel_l1_distance_odd = 0
             self.previous_e0_odd = modulated_inp.clone()
 
-    if self.enable_teacache:
+    if teacache_switch and self.enable_teacache:
         if self.is_even:
             if not should_calc_even:
                 logging.info("use residual estimation for this difusion step")
@@ -235,10 +239,11 @@ def usp_dit_forward(
 
     # Context Parallel
     x = get_sp_group().all_gather(x, dim=1)
-
-    self.cnt += 1
-    if self.cnt >= self.num_steps:
-        self.cnt = 0
+
+    if teacache_switch:
+        self.cnt += 1
+        if self.cnt >= self.num_steps:
+            self.cnt = 0
 
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
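
Not part of the diff: a minimal standalone sketch of the guard pattern the patch introduces. Unpatched modules never receive the teacache attributes (`enable_teacache`, `cnt`, `num_steps`), so unconditional access raises AttributeError; checking `hasattr()` first, or equivalently using `getattr()` with a default, keeps the forward pass safe for both patched and unpatched models. The `Model` class below is hypothetical and only mirrors the names used in the diff.

# sketch of the hasattr/getattr guard from the diff above (illustrative only)
class Model:
    def __init__(self, patched: bool):
        if patched:
            # only patched models carry teacache bookkeeping state
            self.enable_teacache = True
            self.cnt = 0
            self.num_steps = 4

    def forward_step(self):
        # equivalent to the diff's teacache_switch: False when the
        # attribute is absent, the attribute's value otherwise
        teacache_switch = getattr(self, "enable_teacache", False)
        if teacache_switch:
            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0  # wrap the step counter, as in the diff
        return teacache_switch

print(Model(patched=True).forward_step())   # True
print(Model(patched=False).forward_step())  # False, and no AttributeError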