Merge pull request #2 from bytedance-iaas/mul_device_teacache

[FEAT] support Mul device teacache
This commit is contained in:
zc8gerard 2025-05-26 21:59:27 +08:00 committed by GitHub
commit 3c7c6f8b29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,8 @@
import torch
import torch.cuda.amp as amp
import numpy as np
import logging
import yunchang
from yunchang.kernels import AttnType
@ -176,14 +178,81 @@ def usp_dit_forward(
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)
teacache_switch = True
if not hasattr(self, "enable_teacache"):
teacache_switch = False
if teacache_switch and self.enable_teacache:
# print("enable teacache")
modulated_inp = e0 if self.use_ref_steps else e
device_id = torch.cuda.current_device()
# print("check args, {} {} {} {} {}".format(self.cnt, self.ret_steps, self.cutoff_steps, self.coefficients,
# device_id))
# teacache
if self.cnt%2==0: # even -> conditon
self.is_even = True
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc_even = False
else:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
else: # odd -> unconditon
self.is_even = False
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc_odd = False
else:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
if teacache_switch and self.enable_teacache:
if self.is_even:
if not should_calc_even:
logging.info("use residual estimation for this difusion step")
x += self.previous_residual_even
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_even = x - ori_x
else:
if not should_calc_odd:
logging.info("use residual estimation for thi8s difusion step")
x += self.previous_residual_odd
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_odd = x - ori_x
else:
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
if teacache_switch:
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
# unpatchify
x = self.unpatchify(x, grid_sizes)