mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-13 19:20:09 +00:00
merge taylor cache in generate.py and add config args for taylor cache
This commit is contained in:
parent
caac4152cf
commit
ce7766085e
40
generate.py
40
generate.py
@ -30,6 +30,9 @@ from wan.modules.model import sinusoidal_embedding_1d
|
|||||||
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
get_sampling_sigmas, retrieve_timesteps)
|
||||||
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate
|
||||||
|
from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward
|
||||||
|
import types
|
||||||
|
|
||||||
|
|
||||||
EXAMPLE_PROMPT = {
|
EXAMPLE_PROMPT = {
|
||||||
@ -802,7 +805,7 @@ def _parse_args():
|
|||||||
"--enable_teacache",
|
"--enable_teacache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help=" use ret_steps or not")
|
help="enable teacache or not")
|
||||||
|
|
||||||
#teacache_thresh
|
#teacache_thresh
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -811,6 +814,27 @@ def _parse_args():
|
|||||||
default= 0.2,
|
default= 0.2,
|
||||||
help="tea_cache threshold")
|
help="tea_cache threshold")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_taylor_cache",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="enable taylor cache or not"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--taylor_cache_max_order",
|
||||||
|
type=int,
|
||||||
|
default= -1,
|
||||||
|
help="taylor cache max_order or not"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--taylor_cache_fresh_threshold",
|
||||||
|
type=int,
|
||||||
|
default= -1,
|
||||||
|
help="taylor cache fresh threshold"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-fa3",
|
"--enable-fa3",
|
||||||
"--enable_fa3",
|
"--enable_fa3",
|
||||||
@ -940,7 +964,7 @@ def generate(args):
|
|||||||
t5_fsdp=args.t5_fsdp,
|
t5_fsdp=args.t5_fsdp,
|
||||||
dit_fsdp=args.dit_fsdp,
|
dit_fsdp=args.dit_fsdp,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu, use_taylor_cache= args.enable_taylor_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.enable_fa3:
|
if args.enable_fa3:
|
||||||
@ -980,6 +1004,18 @@ def generate(args):
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
|
if args.enable_taylor_cache:
|
||||||
|
wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v)
|
||||||
|
cache_update_dic = {}
|
||||||
|
if args.taylor_cache_max_order > 0:
|
||||||
|
cache_update_dic["max_order"] = args.taylor_cache_max_order
|
||||||
|
if args.taylor_cache_fresh_threshold > 0:
|
||||||
|
cache_update_dic["fresh_threshold"] = args.taylor_cache_fresh_threshold
|
||||||
|
|
||||||
|
# print("*****************>>>>>>>>>>>>>>>>>>",cache_update_dic)
|
||||||
|
wan_t2v.model.cache_update(**cache_update_dic)
|
||||||
|
|
||||||
video = wan_t2v.generate(
|
video = wan_t2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
size=SIZE_CONFIGS[args.size],
|
size=SIZE_CONFIGS[args.size],
|
||||||
|
@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
|
||||||
from .attention import flash_attention
|
from .attention import flash_attention
|
||||||
from wan.taylorseer.cache_functions import cache_init
|
from wan.taylorseer.cache_functions import cache_init, cache_update
|
||||||
|
|
||||||
__all__ = ['WanModel']
|
__all__ = ['WanModel']
|
||||||
|
|
||||||
@ -512,6 +512,12 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
def cache_init(self):
|
def cache_init(self):
|
||||||
self.cache_dic, self.current = cache_init(self)
|
self.cache_dic, self.current = cache_init(self)
|
||||||
|
|
||||||
|
def cache_update(self, **kwargs):
|
||||||
|
# print("############### kwargs: ", kwargs)
|
||||||
|
if "max_order" in kwargs or "fresh_threshold" in kwargs:
|
||||||
|
self.cache_dic, self.current = cache_update(self, **kwargs)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -519,7 +525,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
context,
|
context,
|
||||||
seq_len,
|
seq_len,
|
||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
y=None,
|
y=None
|
||||||
|
# **kwargs
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Forward pass through the diffusion model
|
Forward pass through the diffusion model
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from .cache_init import cache_init
|
from .cache_init import cache_init, cache_update
|
||||||
from .cal_type import cal_type
|
from .cal_type import cal_type
|
||||||
from .force_scheduler import force_scheduler
|
from .force_scheduler import force_scheduler
|
@ -65,3 +65,72 @@ def cache_init(self, num_steps= 50):
|
|||||||
current['num_steps'] = num_steps
|
current['num_steps'] = num_steps
|
||||||
|
|
||||||
return cache_dic, current
|
return cache_dic, current
|
||||||
|
|
||||||
|
def cache_update(self, num_steps= 50, **kwargs):
|
||||||
|
'''
|
||||||
|
Initialization for cache.
|
||||||
|
'''
|
||||||
|
cache_dic = {}
|
||||||
|
cache = {}
|
||||||
|
cache[-1]={}
|
||||||
|
cache[-1]['cond_stream']={}
|
||||||
|
cache[-1]['uncond_stream']={}
|
||||||
|
cache_dic['cache_counter'] = 0
|
||||||
|
|
||||||
|
for j in range(self.num_layers):
|
||||||
|
cache[-1]['cond_stream'][j] = {}
|
||||||
|
cache[-1]['uncond_stream'][j] = {}
|
||||||
|
|
||||||
|
cache_dic['taylor_cache'] = False
|
||||||
|
cache_dic['Delta-DiT'] = False
|
||||||
|
|
||||||
|
|
||||||
|
cache_dic['cache_type'] = 'random'
|
||||||
|
cache_dic['fresh_ratio_schedule'] = 'ToCa'
|
||||||
|
cache_dic['fresh_ratio'] = 0.0
|
||||||
|
cache_dic['fresh_threshold'] = 1
|
||||||
|
cache_dic['force_fresh'] = 'global'
|
||||||
|
|
||||||
|
mode = 'Taylor'
|
||||||
|
|
||||||
|
if mode == 'original':
|
||||||
|
cache_dic['cache'] = cache
|
||||||
|
cache_dic['force_fresh'] = 'global'
|
||||||
|
cache_dic['max_order'] = 0
|
||||||
|
cache_dic['first_enhance'] = 3
|
||||||
|
|
||||||
|
elif mode == 'ToCa':
|
||||||
|
cache_dic['cache_type'] = 'attention'
|
||||||
|
cache_dic['cache'] = cache
|
||||||
|
cache_dic['fresh_ratio_schedule'] = 'ToCa'
|
||||||
|
cache_dic['fresh_ratio'] = 0.1
|
||||||
|
cache_dic['fresh_threshold'] = 5
|
||||||
|
cache_dic['force_fresh'] = 'global'
|
||||||
|
cache_dic['soft_fresh_weight'] = 0.0
|
||||||
|
cache_dic['max_order'] = 0
|
||||||
|
cache_dic['first_enhance'] = 3
|
||||||
|
|
||||||
|
elif mode == 'Taylor':
|
||||||
|
taylor_max_order = kwargs.get("max_order", 1)
|
||||||
|
taylor_fresh_threshold = kwargs.get("fresh_threshold", 5)
|
||||||
|
# print("*****************use max_order : {} fresh_threshold : {}".format(taylor_max_order, taylor_fresh_threshold))
|
||||||
|
cache_dic['cache'] = cache
|
||||||
|
cache_dic['fresh_threshold'] = taylor_fresh_threshold
|
||||||
|
cache_dic['taylor_cache'] = True
|
||||||
|
cache_dic['max_order'] = taylor_max_order
|
||||||
|
cache_dic['first_enhance'] = 1
|
||||||
|
|
||||||
|
elif mode == 'Delta':
|
||||||
|
cache_dic['cache'] = cache
|
||||||
|
cache_dic['fresh_ratio'] = 0.0
|
||||||
|
cache_dic['fresh_threshold'] = 3
|
||||||
|
cache_dic['Delta-DiT'] = True
|
||||||
|
cache_dic['max_order'] = 0
|
||||||
|
cache_dic['first_enhance'] = 1
|
||||||
|
|
||||||
|
current = {}
|
||||||
|
current['activated_steps'] = [0]
|
||||||
|
current['step'] = 0
|
||||||
|
current['num_steps'] = num_steps
|
||||||
|
|
||||||
|
return cache_dic, current
|
||||||
|
Loading…
Reference in New Issue
Block a user