mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	merge taylor cache in generate.py and add config args for taylor cache
This commit is contained in:
		
							parent
							
								
									caac4152cf
								
							
						
					
					
						commit
						ce7766085e
					
				
							
								
								
									
										40
									
								
								generate.py
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								generate.py
									
									
									
									
									
								
							@ -30,6 +30,9 @@ from wan.modules.model import sinusoidal_embedding_1d
 | 
			
		||||
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
 | 
			
		||||
                               get_sampling_sigmas, retrieve_timesteps)
 | 
			
		||||
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 | 
			
		||||
from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate
 | 
			
		||||
from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward
 | 
			
		||||
import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EXAMPLE_PROMPT = {
 | 
			
		||||
@ -802,7 +805,7 @@ def _parse_args():
 | 
			
		||||
        "--enable_teacache",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        default=False,
 | 
			
		||||
        help=" use ret_steps or not")
 | 
			
		||||
        help="enable teacache or not")
 | 
			
		||||
    
 | 
			
		||||
    #teacache_thresh
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
@ -811,6 +814,27 @@ def _parse_args():
 | 
			
		||||
        default= 0.2,
 | 
			
		||||
        help="tea_cache threshold")
 | 
			
		||||
    
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--enable_taylor_cache",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        default=False,
 | 
			
		||||
        help="enable taylor cache or not"
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--taylor_cache_max_order",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default= -1,
 | 
			
		||||
        help="taylor cache max_order or not"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--taylor_cache_fresh_threshold",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default= -1,
 | 
			
		||||
        help="taylor cache fresh threshold"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--enable-fa3",
 | 
			
		||||
        "--enable_fa3",
 | 
			
		||||
@ -940,7 +964,7 @@ def generate(args):
 | 
			
		||||
            t5_fsdp=args.t5_fsdp,
 | 
			
		||||
            dit_fsdp=args.dit_fsdp,
 | 
			
		||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
            t5_cpu=args.t5_cpu,
 | 
			
		||||
            t5_cpu=args.t5_cpu, use_taylor_cache= args.enable_taylor_cache
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if args.enable_fa3:
 | 
			
		||||
@ -980,6 +1004,18 @@ def generate(args):
 | 
			
		||||
        logging.info(
 | 
			
		||||
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
 | 
			
		||||
        start = time.time()
 | 
			
		||||
 | 
			
		||||
        if args.enable_taylor_cache:
 | 
			
		||||
            wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v)
 | 
			
		||||
            cache_update_dic = {}
 | 
			
		||||
            if args.taylor_cache_max_order > 0:
 | 
			
		||||
                cache_update_dic["max_order"] = args.taylor_cache_max_order
 | 
			
		||||
            if args.taylor_cache_fresh_threshold > 0:
 | 
			
		||||
                cache_update_dic["fresh_threshold"] = args.taylor_cache_fresh_threshold
 | 
			
		||||
            
 | 
			
		||||
            # print("*****************>>>>>>>>>>>>>>>>>>",cache_update_dic)
 | 
			
		||||
            wan_t2v.model.cache_update(**cache_update_dic)
 | 
			
		||||
        
 | 
			
		||||
        video = wan_t2v.generate(
 | 
			
		||||
            args.prompt,
 | 
			
		||||
            size=SIZE_CONFIGS[args.size],
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 | 
			
		||||
from diffusers.models.modeling_utils import ModelMixin
 | 
			
		||||
 | 
			
		||||
from .attention import flash_attention
 | 
			
		||||
from wan.taylorseer.cache_functions import cache_init
 | 
			
		||||
from wan.taylorseer.cache_functions import cache_init, cache_update
 | 
			
		||||
 | 
			
		||||
__all__ = ['WanModel']
 | 
			
		||||
 | 
			
		||||
@ -512,6 +512,12 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
 | 
			
		||||
    def cache_init(self):
        """Create the cache state and attach it to this model.

        Calls the module-level ``cache_init`` helper with this instance and
        stores the two values it returns as ``self.cache_dic`` and
        ``self.current``.
        """
        fresh_state = cache_init(self)
        self.cache_dic, self.current = fresh_state
 | 
			
		||||
    
 | 
			
		||||
    def cache_update(self, **kwargs):
        """Rebuild the cache state when a recognised override is supplied.

        The module-level ``cache_update`` helper is called only if ``kwargs``
        contains ``max_order`` and/or ``fresh_threshold``; its two return
        values then replace ``self.cache_dic`` and ``self.current``. With
        neither key present the existing state is left untouched.
        """
        recognised_keys = ("max_order", "fresh_threshold")
        if not any(key in kwargs for key in recognised_keys):
            return
        # Inside a method body the bare name resolves to the module-level
        # helper, not to this method.
        self.cache_dic, self.current = cache_update(self, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x,
 | 
			
		||||
@ -519,7 +525,8 @@ class WanModel(ModelMixin, ConfigMixin):
 | 
			
		||||
        context,
 | 
			
		||||
        seq_len,
 | 
			
		||||
        clip_fea=None,
 | 
			
		||||
        y=None,
 | 
			
		||||
        y=None
 | 
			
		||||
        # **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        r"""
 | 
			
		||||
        Forward pass through the diffusion model
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,3 @@
 | 
			
		||||
from .cache_init import cache_init
 | 
			
		||||
from .cache_init import cache_init, cache_update
 | 
			
		||||
from .cal_type import cal_type
 | 
			
		||||
from .force_scheduler import force_scheduler
 | 
			
		||||
@ -65,3 +65,72 @@ def cache_init(self, num_steps= 50):
 | 
			
		||||
    current['num_steps'] = num_steps
 | 
			
		||||
 | 
			
		||||
    return cache_dic, current
 | 
			
		||||
 | 
			
		||||
def cache_update(self, num_steps=50, **kwargs):
    """Build a fresh Taylor-mode cache state with optional overrides.

    Despite the name, nothing is modified in place: a brand-new cache
    dictionary and step tracker are created and returned, and the caller
    decides where to store them.

    Args:
        self: Object exposing ``num_layers``; one empty per-layer slot is
            created for both the ``cond_stream`` and the ``uncond_stream``.
        num_steps: Value stored under ``current['num_steps']``.
        **kwargs: Optional overrides. Only ``max_order`` (default ``1``) and
            ``fresh_threshold`` (default ``5``) are read; any other key is
            ignored.

    Returns:
        tuple: ``(cache_dic, current)`` where ``cache_dic`` holds the cache
        configuration plus the empty per-layer storage under ``'cache'``,
        and ``current`` holds the step-tracking fields.
    """
    layer_ids = range(self.num_layers)

    # Per-layer storage lives under the single top-level key -1; every layer
    # gets its own distinct empty dict in each stream.
    cache = {
        -1: {
            'cond_stream': {layer: {} for layer in layer_ids},
            'uncond_stream': {layer: {} for layer in layer_ids},
        }
    }

    # The mode was hard-coded to 'Taylor', so only that configuration could
    # ever be produced; it is spelled out directly here.
    cache_dic = {
        'cache_counter': 0,
        'taylor_cache': True,
        'Delta-DiT': False,
        'cache_type': 'random',
        'fresh_ratio_schedule': 'ToCa',
        'fresh_ratio': 0.0,
        'fresh_threshold': kwargs.get("fresh_threshold", 5),
        'force_fresh': 'global',
        'cache': cache,
        'max_order': kwargs.get("max_order", 1),
        'first_enhance': 1,
    }

    current = {
        'activated_steps': [0],
        'step': 0,
        'num_steps': num_steps,
    }

    return cache_dic, current
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user