From ce7766085ebdb33dd5eb2569dfb007e10ff50d18 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 12 Jun 2025 06:52:14 +0000 Subject: [PATCH] merge taylor cache in generate.py and add config args for taylor cache --- generate.py | 40 +++++++++++- wan/modules/model.py | 11 +++- wan/taylorseer/cache_functions/__init__.py | 2 +- wan/taylorseer/cache_functions/cache_init.py | 69 ++++++++++++++++++++ 4 files changed, 117 insertions(+), 5 deletions(-) diff --git a/generate.py b/generate.py index 2ba852d..6f1f894 100644 --- a/generate.py +++ b/generate.py @@ -30,6 +30,9 @@ from wan.modules.model import sinusoidal_embedding_1d from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate +from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward +import types EXAMPLE_PROMPT = { @@ -802,7 +805,7 @@ def _parse_args(): "--enable_teacache", action="store_true", default=False, - help=" use ret_steps or not") + help="enable teacache or not") #teacache_thresh parser.add_argument( @@ -811,6 +814,27 @@ def _parse_args(): default= 0.2, help="tea_cache threshold") + parser.add_argument( + "--enable_taylor_cache", + action="store_true", + default=False, + help="enable taylor cache or not" + ) + + parser.add_argument( + "--taylor_cache_max_order", + type=int, + default= -1, + help="taylor cache max_order or not" + ) + + parser.add_argument( + "--taylor_cache_fresh_threshold", + type=int, + default= -1, + help="taylor cache fresh threshold" + ) + parser.add_argument( "--enable-fa3", "--enable_fa3", @@ -940,7 +964,7 @@ def generate(args): t5_fsdp=args.t5_fsdp, dit_fsdp=args.dit_fsdp, use_usp=(args.ulysses_size > 1 or args.ring_size > 1), - t5_cpu=args.t5_cpu, + t5_cpu=args.t5_cpu, use_taylor_cache= args.enable_taylor_cache ) if args.enable_fa3: @@ -980,6 +1004,18 @@ def 
generate(args): logging.info( f"Generating {'image' if 't2i' in args.task else 'video'} ...") start = time.time() + + if args.enable_taylor_cache: + wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v) + cache_update_dic = {} + if args.taylor_cache_max_order > 0: + cache_update_dic["max_order"] = args.taylor_cache_max_order + if args.taylor_cache_fresh_threshold > 0: + cache_update_dic["fresh_threshold"] = args.taylor_cache_fresh_threshold + + # print("*****************>>>>>>>>>>>>>>>>>>",cache_update_dic) + wan_t2v.model.cache_update(**cache_update_dic) + video = wan_t2v.generate( args.prompt, size=SIZE_CONFIGS[args.size], diff --git a/wan/modules/model.py b/wan/modules/model.py index 7bc24bc..96653e2 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from .attention import flash_attention -from wan.taylorseer.cache_functions import cache_init +from wan.taylorseer.cache_functions import cache_init, cache_update __all__ = ['WanModel'] @@ -512,6 +512,12 @@ class WanModel(ModelMixin, ConfigMixin): def cache_init(self): self.cache_dic, self.current = cache_init(self) + + def cache_update(self, **kwargs): + # print("############### kwargs: ", kwargs) + if "max_order" in kwargs or "fresh_threshold" in kwargs: + self.cache_dic, self.current = cache_update(self, **kwargs) + def forward( self, x, @@ -519,7 +525,8 @@ class WanModel(ModelMixin, ConfigMixin): context, seq_len, clip_fea=None, - y=None, + y=None + # **kwargs ): r""" Forward pass through the diffusion model diff --git a/wan/taylorseer/cache_functions/__init__.py b/wan/taylorseer/cache_functions/__init__.py index b790804..c1ff4b2 100644 --- a/wan/taylorseer/cache_functions/__init__.py +++ b/wan/taylorseer/cache_functions/__init__.py @@ -1,3 +1,3 @@ -from .cache_init import cache_init +from .cache_init import cache_init, cache_update from .cal_type 
# Added to wan/taylorseer/cache_functions/cache_init.py, right after cache_init().
def cache_update(self, num_steps=50, *, mode='Taylor', **kwargs):
    """Build a fresh cache configuration, optionally overriding Taylor settings.

    Nothing is copied from any existing cache: the function always returns
    newly created dictionaries, with one empty per-layer dict for both the
    ``cond_stream`` and ``uncond_stream`` entries.

    Args:
        self: Object exposing ``num_layers`` (the number of per-layer cache
            slots to create).
        num_steps: Value stored under ``current['num_steps']``.
        mode: Name of the preset layered on top of the shared defaults; one
            of ``'original'``, ``'ToCa'``, ``'Taylor'`` or ``'Delta'``.
            Defaults to ``'Taylor'``, which is what this function always
            used before the parameter existed.
        **kwargs: ``max_order`` (default 1) and ``fresh_threshold``
            (default 5) override the corresponding ``'Taylor'`` preset
            values. They are only consulted in ``'Taylor'`` mode; any other
            keyword is ignored.

    Returns:
        A ``(cache_dic, current)`` tuple: the cache configuration dict and
        the step-tracking dict (``activated_steps``, ``step``,
        ``num_steps``).

    Raises:
        ValueError: If ``mode`` is not one of the known presets.
    """
    # Per-mode overrides; each one is applied on top of the shared defaults
    # assembled below, so a preset only lists the keys it changes or adds.
    presets = {
        'original': {
            'force_fresh': 'global',
            'max_order': 0,
            'first_enhance': 3,
        },
        'ToCa': {
            'cache_type': 'attention',
            'fresh_ratio_schedule': 'ToCa',
            'fresh_ratio': 0.1,
            'fresh_threshold': 5,
            'force_fresh': 'global',
            'soft_fresh_weight': 0.0,
            'max_order': 0,
            'first_enhance': 3,
        },
        'Taylor': {
            'fresh_threshold': kwargs.get('fresh_threshold', 5),
            'taylor_cache': True,
            'max_order': kwargs.get('max_order', 1),
            'first_enhance': 1,
        },
        'Delta': {
            'fresh_ratio': 0.0,
            'fresh_threshold': 3,
            'Delta-DiT': True,
            'max_order': 0,
            'first_enhance': 1,
        },
    }
    if mode not in presets:
        raise ValueError(
            f"unknown cache mode {mode!r}; expected one of {sorted(presets)}")

    # Slot -1 holds an independent empty dict per layer for each stream.
    layer_ids = range(self.num_layers)
    cache = {
        -1: {
            'cond_stream': {j: {} for j in layer_ids},
            'uncond_stream': {j: {} for j in layer_ids},
        },
    }

    # Defaults shared by every mode, then the mode-specific overrides.
    cache_dic = {
        'cache_counter': 0,
        'taylor_cache': False,
        'Delta-DiT': False,
        'cache_type': 'random',
        'fresh_ratio_schedule': 'ToCa',
        'fresh_ratio': 0.0,
        'fresh_threshold': 1,
        'force_fresh': 'global',
        'cache': cache,
    }
    cache_dic.update(presets[mode])

    current = {
        'activated_steps': [0],
        'step': 0,
        'num_steps': num_steps,
    }

    return cache_dic, current