mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-13 11:10:11 +00:00
merge taylor cache in generate.py and add config args for taylor cache
This commit is contained in:
parent
caac4152cf
commit
ce7766085e
40
generate.py
40
generate.py
@ -30,6 +30,9 @@ from wan.modules.model import sinusoidal_embedding_1d
|
||||
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps)
|
||||
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate
|
||||
from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward
|
||||
import types
|
||||
|
||||
|
||||
EXAMPLE_PROMPT = {
|
||||
@ -802,7 +805,7 @@ def _parse_args():
|
||||
"--enable_teacache",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=" use ret_steps or not")
|
||||
help="enable teacache or not")
|
||||
|
||||
#teacache_thresh
|
||||
parser.add_argument(
|
||||
@ -811,6 +814,27 @@ def _parse_args():
|
||||
default= 0.2,
|
||||
help="tea_cache threshold")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_taylor_cache",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="enable taylor cache or not"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--taylor_cache_max_order",
|
||||
type=int,
|
||||
default= -1,
|
||||
help="taylor cache max_order or not"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--taylor_cache_fresh_threshold",
|
||||
type=int,
|
||||
default= -1,
|
||||
help="taylor cache fresh threshold"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-fa3",
|
||||
"--enable_fa3",
|
||||
@ -940,7 +964,7 @@ def generate(args):
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
t5_cpu=args.t5_cpu, use_taylor_cache= args.enable_taylor_cache
|
||||
)
|
||||
|
||||
if args.enable_fa3:
|
||||
@ -980,6 +1004,18 @@ def generate(args):
|
||||
logging.info(
|
||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||
start = time.time()
|
||||
|
||||
if args.enable_taylor_cache:
|
||||
wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v)
|
||||
cache_update_dic = {}
|
||||
if args.taylor_cache_max_order > 0:
|
||||
cache_update_dic["max_order"] = args.taylor_cache_max_order
|
||||
if args.taylor_cache_fresh_threshold > 0:
|
||||
cache_update_dic["fresh_threshold"] = args.taylor_cache_fresh_threshold
|
||||
|
||||
# print("*****************>>>>>>>>>>>>>>>>>>",cache_update_dic)
|
||||
wan_t2v.model.cache_update(**cache_update_dic)
|
||||
|
||||
video = wan_t2v.generate(
|
||||
args.prompt,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
|
@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from .attention import flash_attention
|
||||
from wan.taylorseer.cache_functions import cache_init
|
||||
from wan.taylorseer.cache_functions import cache_init, cache_update
|
||||
|
||||
__all__ = ['WanModel']
|
||||
|
||||
@ -512,6 +512,12 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
def cache_init(self):
|
||||
self.cache_dic, self.current = cache_init(self)
|
||||
|
||||
def cache_update(self, **kwargs):
|
||||
# print("############### kwargs: ", kwargs)
|
||||
if "max_order" in kwargs or "fresh_threshold" in kwargs:
|
||||
self.cache_dic, self.current = cache_update(self, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
@ -519,7 +525,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
context,
|
||||
seq_len,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
y=None
|
||||
# **kwargs
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .cache_init import cache_init
|
||||
from .cache_init import cache_init, cache_update
|
||||
from .cal_type import cal_type
|
||||
from .force_scheduler import force_scheduler
|
@ -65,3 +65,72 @@ def cache_init(self, num_steps= 50):
|
||||
current['num_steps'] = num_steps
|
||||
|
||||
return cache_dic, current
|
||||
|
||||
def cache_update(self, num_steps= 50, **kwargs):
|
||||
'''
|
||||
Initialization for cache.
|
||||
'''
|
||||
cache_dic = {}
|
||||
cache = {}
|
||||
cache[-1]={}
|
||||
cache[-1]['cond_stream']={}
|
||||
cache[-1]['uncond_stream']={}
|
||||
cache_dic['cache_counter'] = 0
|
||||
|
||||
for j in range(self.num_layers):
|
||||
cache[-1]['cond_stream'][j] = {}
|
||||
cache[-1]['uncond_stream'][j] = {}
|
||||
|
||||
cache_dic['taylor_cache'] = False
|
||||
cache_dic['Delta-DiT'] = False
|
||||
|
||||
|
||||
cache_dic['cache_type'] = 'random'
|
||||
cache_dic['fresh_ratio_schedule'] = 'ToCa'
|
||||
cache_dic['fresh_ratio'] = 0.0
|
||||
cache_dic['fresh_threshold'] = 1
|
||||
cache_dic['force_fresh'] = 'global'
|
||||
|
||||
mode = 'Taylor'
|
||||
|
||||
if mode == 'original':
|
||||
cache_dic['cache'] = cache
|
||||
cache_dic['force_fresh'] = 'global'
|
||||
cache_dic['max_order'] = 0
|
||||
cache_dic['first_enhance'] = 3
|
||||
|
||||
elif mode == 'ToCa':
|
||||
cache_dic['cache_type'] = 'attention'
|
||||
cache_dic['cache'] = cache
|
||||
cache_dic['fresh_ratio_schedule'] = 'ToCa'
|
||||
cache_dic['fresh_ratio'] = 0.1
|
||||
cache_dic['fresh_threshold'] = 5
|
||||
cache_dic['force_fresh'] = 'global'
|
||||
cache_dic['soft_fresh_weight'] = 0.0
|
||||
cache_dic['max_order'] = 0
|
||||
cache_dic['first_enhance'] = 3
|
||||
|
||||
elif mode == 'Taylor':
|
||||
taylor_max_order = kwargs.get("max_order", 1)
|
||||
taylor_fresh_threshold = kwargs.get("fresh_threshold", 5)
|
||||
# print("*****************use max_order : {} fresh_threshold : {}".format(taylor_max_order, taylor_fresh_threshold))
|
||||
cache_dic['cache'] = cache
|
||||
cache_dic['fresh_threshold'] = taylor_fresh_threshold
|
||||
cache_dic['taylor_cache'] = True
|
||||
cache_dic['max_order'] = taylor_max_order
|
||||
cache_dic['first_enhance'] = 1
|
||||
|
||||
elif mode == 'Delta':
|
||||
cache_dic['cache'] = cache
|
||||
cache_dic['fresh_ratio'] = 0.0
|
||||
cache_dic['fresh_threshold'] = 3
|
||||
cache_dic['Delta-DiT'] = True
|
||||
cache_dic['max_order'] = 0
|
||||
cache_dic['first_enhance'] = 1
|
||||
|
||||
current = {}
|
||||
current['activated_steps'] = [0]
|
||||
current['step'] = 0
|
||||
current['num_steps'] = num_steps
|
||||
|
||||
return cache_dic, current
|
||||
|
Loading…
Reference in New Issue
Block a user