diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index 4718577..f07661d 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -1,6 +1,9 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+import numpy as np
+import logging
+
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
@@ -166,8 +169,66 @@ def usp_dit_forward(
         kwargs['hints'] = hints
         kwargs['context_scale'] = vace_context_scale
 
-    for block in self.blocks:
-        x = block(x, **kwargs)
+    if self.enable_teacache:
+        # print("enable teacache")
+        modulated_inp = e0 if self.use_ref_steps else e
+        device_id = torch.cuda.current_device()
+        # print("check args, {} {} {} {} {}".format(self.cnt, self.ret_steps, self.cutoff_steps,
+        #                                           self.coefficients, device_id))
+        # TeaCache: decide whether this step can reuse the cached block residual.
+        if self.cnt % 2 == 0:  # even -> conditional pass
+            self.is_even = True
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_even = True
+                self.accumulated_rel_l1_distance_even = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                    should_calc_even = False
+                else:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+            self.previous_e0_even = modulated_inp.clone()
+
+        else:  # odd -> unconditional pass
+            self.is_even = False
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_odd = True
+                self.accumulated_rel_l1_distance_odd = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                    should_calc_odd = False
+                else:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+            self.previous_e0_odd = modulated_inp.clone()
+
+    if self.enable_teacache:
+        if self.is_even:
+            if not should_calc_even:
+                logging.info("use residual estimation for this diffusion step")
+                x += self.previous_residual_even
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_even = x - ori_x
+        else:
+            if not should_calc_odd:
+                logging.info("use residual estimation for this diffusion step")
+                x += self.previous_residual_odd
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_odd = x - ori_x
+
+    else:
+        for block in self.blocks:
+            x = block(x, **kwargs)
 
     # head
     x = self.head(x, e)
@@ -175,6 +236,10 @@
     # Context Parallel
     x = get_sp_group().all_gather(x, dim=1)
 
+    self.cnt += 1
+    if self.cnt >= self.num_steps:
+        self.cnt = 0
+
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
     return [u.float() for u in x]
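
Note: the patch assumes the TeaCache bookkeeping state (enable_teacache, use_ref_steps, teacache_thresh, coefficients, cnt, num_steps, ret_steps, cutoff_steps, plus the accumulated_*/previous_* caches) is attached to the model before sampling; it is not initialized in this file. Below is a minimal sketch of that setup, assuming a model instance named "model" and a 50-step sampling loop with classifier-free guidance; the attribute names follow the patch above, but the numeric values are illustrative placeholders rather than project defaults.

    # Sketch only: attach TeaCache state to the model before generation (values are examples).
    model.enable_teacache = True
    model.use_ref_steps = True           # compare e0 (the modulated input) rather than e
    model.teacache_thresh = 0.2          # reuse cached residual while accumulated distance stays below this
    model.coefficients = [1.0, 0.0, 0.0, 0.0, 0.0]  # np.poly1d rescaling coefficients (model-specific)
    model.num_steps = 50 * 2             # sampling steps x 2 (conditional + unconditional forward per step)
    model.cnt = 0
    model.ret_steps = 1 * 2              # always recompute the first step
    model.cutoff_steps = 50 * 2 - 2      # always recompute the last step
    model.accumulated_rel_l1_distance_even = 0
    model.accumulated_rel_l1_distance_odd = 0
    model.previous_e0_even = None
    model.previous_e0_odd = None
    model.previous_residual_even = None
    model.previous_residual_odd = None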