merge taylor cache in generate.py and add config args for taylor cache

This commit is contained in:
Your Name 2025-06-12 06:52:14 +00:00
parent caac4152cf
commit ce7766085e
4 changed files with 117 additions and 5 deletions

View File

@ -30,6 +30,9 @@ from wan.modules.model import sinusoidal_embedding_1d
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate
from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward
import types
EXAMPLE_PROMPT = {
@ -802,7 +805,7 @@ def _parse_args():
"--enable_teacache",
action="store_true",
default=False,
help=" use ret_steps or not")
help="enable teacache or not")
#teacache_thresh
parser.add_argument(
@ -811,6 +814,27 @@ def _parse_args():
default= 0.2,
help="tea_cache threshold")
parser.add_argument(
"--enable_taylor_cache",
action="store_true",
default=False,
help="enable taylor cache or not"
)
parser.add_argument(
"--taylor_cache_max_order",
type=int,
default= -1,
help="taylor cache max_order or not"
)
parser.add_argument(
"--taylor_cache_fresh_threshold",
type=int,
default= -1,
help="taylor cache fresh threshold"
)
parser.add_argument(
"--enable-fa3",
"--enable_fa3",
@ -940,7 +964,7 @@ def generate(args):
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
t5_cpu=args.t5_cpu, use_taylor_cache= args.enable_taylor_cache
)
if args.enable_fa3:
@ -980,6 +1004,18 @@ def generate(args):
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start = time.time()
if args.enable_taylor_cache:
wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v)
cache_update_dic = {}
if args.taylor_cache_max_order > 0:
cache_update_dic["max_order"] = args.taylor_cache_max_order
if args.taylor_cache_fresh_threshold > 0:
cache_update_dic["fresh_threshold"] = args.taylor_cache_fresh_threshold
# print("*****************>>>>>>>>>>>>>>>>>>",cache_update_dic)
wan_t2v.model.cache_update(**cache_update_dic)
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],

View File

@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
from wan.taylorseer.cache_functions import cache_init
from wan.taylorseer.cache_functions import cache_init, cache_update
__all__ = ['WanModel']
@ -512,6 +512,12 @@ class WanModel(ModelMixin, ConfigMixin):
def cache_init(self):
self.cache_dic, self.current = cache_init(self)
def cache_update(self, **kwargs):
# print("############### kwargs: ", kwargs)
if "max_order" in kwargs or "fresh_threshold" in kwargs:
self.cache_dic, self.current = cache_update(self, **kwargs)
def forward(
self,
x,
@ -519,7 +525,8 @@ class WanModel(ModelMixin, ConfigMixin):
context,
seq_len,
clip_fea=None,
y=None,
y=None
# **kwargs
):
r"""
Forward pass through the diffusion model

View File

@ -1,3 +1,3 @@
from .cache_init import cache_init
from .cache_init import cache_init, cache_update
from .cal_type import cal_type
from .force_scheduler import force_scheduler

View File

@ -65,3 +65,72 @@ def cache_init(self, num_steps= 50):
current['num_steps'] = num_steps
return cache_dic, current
def cache_update(self, num_steps= 50, **kwargs):
'''
Initialization for cache.
'''
cache_dic = {}
cache = {}
cache[-1]={}
cache[-1]['cond_stream']={}
cache[-1]['uncond_stream']={}
cache_dic['cache_counter'] = 0
for j in range(self.num_layers):
cache[-1]['cond_stream'][j] = {}
cache[-1]['uncond_stream'][j] = {}
cache_dic['taylor_cache'] = False
cache_dic['Delta-DiT'] = False
cache_dic['cache_type'] = 'random'
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 1
cache_dic['force_fresh'] = 'global'
mode = 'Taylor'
if mode == 'original':
cache_dic['cache'] = cache
cache_dic['force_fresh'] = 'global'
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 3
elif mode == 'ToCa':
cache_dic['cache_type'] = 'attention'
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.1
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 3
elif mode == 'Taylor':
taylor_max_order = kwargs.get("max_order", 1)
taylor_fresh_threshold = kwargs.get("fresh_threshold", 5)
# print("*****************use max_order : {} fresh_threshold : {}".format(taylor_max_order, taylor_fresh_threshold))
cache_dic['cache'] = cache
cache_dic['fresh_threshold'] = taylor_fresh_threshold
cache_dic['taylor_cache'] = True
cache_dic['max_order'] = taylor_max_order
cache_dic['first_enhance'] = 1
elif mode == 'Delta':
cache_dic['cache'] = cache
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 3
cache_dic['Delta-DiT'] = True
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
current = {}
current['activated_steps'] = [0]
current['step'] = 0
current['num_steps'] = num_steps
return cache_dic, current