From 6686e7fd18de7318e490b739c9036b3b9de0d369 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 10:39:21 +0000 Subject: [PATCH 1/2] [feat] enable taylor cache --- taylor_generator.py | 422 ++++++++++++++++++ .../xdit_context_parallel_taylor.py | 192 ++++++++ wan/modules/model.py | 6 +- wan/taylorseer/cache_functions/__init__.py | 3 + wan/taylorseer/cache_functions/cache_init.py | 67 +++ wan/taylorseer/cache_functions/cal_type.py | 42 ++ .../cache_functions/force_scheduler.py | 16 + wan/taylorseer/forwards/__init__.py | 5 + .../forwards/wan_attention_cache_forward.py | 23 + .../forwards/wan_attention_forward.py | 82 ++++ .../wan_attention_forward_cache_step.py | 37 ++ wan/taylorseer/forwards/wan_cache_forward.py | 18 + wan/taylorseer/forwards/wan_forward.py | 114 +++++ .../forwards/xfusers_wan_forward.py | 104 +++++ wan/taylorseer/generates/__init__.py | 2 + wan/taylorseer/generates/wan_i2v_generate.py | 264 +++++++++++ wan/taylorseer/generates/wan_t2v_generate.py | 207 +++++++++ wan/taylorseer/taylorseer_utils/__init__.py | 47 ++ wan/text2video.py | 43 +- 19 files changed, 1686 insertions(+), 8 deletions(-) create mode 100644 taylor_generator.py create mode 100644 wan/distributed/xdit_context_parallel_taylor.py create mode 100644 wan/taylorseer/cache_functions/__init__.py create mode 100644 wan/taylorseer/cache_functions/cache_init.py create mode 100644 wan/taylorseer/cache_functions/cal_type.py create mode 100644 wan/taylorseer/cache_functions/force_scheduler.py create mode 100644 wan/taylorseer/forwards/__init__.py create mode 100644 wan/taylorseer/forwards/wan_attention_cache_forward.py create mode 100644 wan/taylorseer/forwards/wan_attention_forward.py create mode 100644 wan/taylorseer/forwards/wan_attention_forward_cache_step.py create mode 100644 wan/taylorseer/forwards/wan_cache_forward.py create mode 100644 wan/taylorseer/forwards/wan_forward.py create mode 100644 wan/taylorseer/forwards/xfusers_wan_forward.py create mode 100644 wan/taylorseer/generates/__init__.py create mode 100644 wan/taylorseer/generates/wan_i2v_generate.py create mode 100644 wan/taylorseer/generates/wan_t2v_generate.py create mode 100644 wan/taylorseer/taylorseer_utils/__init__.py diff --git a/taylor_generator.py b/taylor_generator.py new file mode 100644 index 0000000..a733ba2 --- /dev/null +++ b/taylor_generator.py @@ -0,0 +1,422 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +from datetime import datetime +import logging +import os +import sys +import warnings + +warnings.filterwarnings('ignore') + +import torch, random +import torch.distributed as dist +from PIL import Image + +import wan +from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES +from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander +from wan.utils.utils import cache_video, cache_image, str2bool +from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate +from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward +import types + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2i-14B": { + "prompt": "一个朴素端庄的美人", + }, + "i2v-14B": { + "prompt": + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. 
The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "image": + "examples/i2v_input.JPG", + }, +} + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 40 if "i2v" in args.task else 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + if "i2v" in args.task and args.size in ["832*480", "480*832"]: + args.sample_shift = 3.0 + + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. + if args.frame_num is None: + args.frame_num = 1 if "t2i" in args.task else 81 + + # T2I frame_num check + if "t2i" in args.task: + assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check + assert args.size in SUPPORTED_SIZES[ + args. + task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + ) + parser.add_argument( + "--frame_num", + type=int, + default=None, + help="How many frames to sample from a image or video. The number should be 4n+1" + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." 
+ ) + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--ring_size", + type=int, + default=1, + help="The size of the ring attention parallelism in DiT.") + parser.add_argument( + "--t5_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for T5.") + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--dit_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for DiT.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompt", + type=str, + default=None, + help="The prompt to generate the image or video from.") + parser.add_argument( + "--use_prompt_extend", + action="store_true", + default=False, + help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="zh", + choices=["zh", "en"], + help="The target language of prompt extend.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--image", + type=str, + default=None, + help="The image to generate the video from.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + else: + assert not ( + args.t5_fsdp or args.dit_fsdp + ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." + assert not ( + args.ulysses_size > 1 or args.ring_size > 1 + ), f"context parallel are not supported in non-distributed environments." + + if args.ulysses_size > 1 or args.ring_size > 1: + assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." 
+ from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) + init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=args.ring_size, + ulysses_degree=args.ulysses_size, + ) + + if args.use_prompt_extend: + if args.prompt_extend_method == "dashscope": + prompt_expander = DashScopePromptExpander( + model_name=args.prompt_extend_model, is_vl="i2v" in args.task) + elif args.prompt_extend_method == "local_qwen": + prompt_expander = QwenPromptExpander( + model_name=args.prompt_extend_model, + is_vl="i2v" in args.task, + device=rank) + else: + raise NotImplementedError( + f"Unsupport prompt_extend_method: {args.prompt_extend_method}") + + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`." + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task or "t2i" in args.task: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + logging.info(f"Input prompt: {args.prompt}") + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanT2V pipeline.") + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, use_taylor_cache= True + ) + + logging.info( + f"Generating {'image' if 't2i' in args.task else 'video'} ...") + + # TaylorSeer + wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v) + #wan_t2v = torch.compile(wan_t2v, mode="max-autotune") + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + else: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None: + args.image = EXAMPLE_PROMPT[args.task]["image"] + logging.info(f"Input prompt: {args.prompt}") + logging.info(f"Input image: {args.image}") + + img = Image.open(args.image).convert("RGB") + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + image=img, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original 
prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info("Generating video ...") + + # TaylorSeer + wan_i2v.generate = types.MethodType(wan_i2v_generate, wan_i2v) + + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", + "_")[:50] + suffix = '.png' if "t2i" in args.task else '.mp4' + args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix + + if "t2i" in args.task: + logging.info(f"Saving generated image to {args.save_file}") + cache_image( + tensor=video.squeeze(1)[None], + save_file=args.save_file, + nrow=1, + normalize=True, + value_range=(-1, 1)) + else: + logging.info(f"Saving generated video to {args.save_file}") + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/wan/distributed/xdit_context_parallel_taylor.py b/wan/distributed/xdit_context_parallel_taylor.py new file mode 100644 index 0000000..01936ce --- /dev/null +++ b/wan/distributed/xdit_context_parallel_taylor.py @@ -0,0 +1,192 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.cuda.amp as amp +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ..modules.model import sinusoidal_embedding_1d + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. 
+ """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * + s_per_rank), :, :] + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def usp_dit_forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def usp_attn_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = 
qkv_fn(x) + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. + # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) + + # output + x = x.flatten(2) + x = self.o(x) + return x diff --git a/wan/modules/model.py b/wan/modules/model.py index d5127fd..7bc24bc 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -8,6 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from .attention import flash_attention +from wan.taylorseer.cache_functions import cache_init __all__ = ['WanModel'] @@ -506,8 +507,11 @@ class WanModel(ModelMixin, ConfigMixin): self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') # initialize weights - self.init_weights() + self.init_weights() + self.cache_init() + def cache_init(self): + self.cache_dic, self.current = cache_init(self) def forward( self, x, diff --git a/wan/taylorseer/cache_functions/__init__.py b/wan/taylorseer/cache_functions/__init__.py new file mode 100644 index 0000000..b790804 --- /dev/null +++ b/wan/taylorseer/cache_functions/__init__.py @@ -0,0 +1,3 @@ +from .cache_init import cache_init +from .cal_type import cal_type +from .force_scheduler import force_scheduler \ No newline at end of file diff --git a/wan/taylorseer/cache_functions/cache_init.py b/wan/taylorseer/cache_functions/cache_init.py new file mode 100644 index 0000000..ce0ce74 --- /dev/null +++ b/wan/taylorseer/cache_functions/cache_init.py @@ -0,0 +1,67 @@ +#from wan.modules import WanModel + +def cache_init(self, num_steps= 50): + ''' + Initialization for cache. 
+ ''' + cache_dic = {} + cache = {} + cache[-1]={} + cache[-1]['cond_stream']={} + cache[-1]['uncond_stream']={} + cache_dic['cache_counter'] = 0 + + for j in range(self.num_layers): + cache[-1]['cond_stream'][j] = {} + cache[-1]['uncond_stream'][j] = {} + + cache_dic['taylor_cache'] = False + cache_dic['Delta-DiT'] = False + + + cache_dic['cache_type'] = 'random' + cache_dic['fresh_ratio_schedule'] = 'ToCa' + cache_dic['fresh_ratio'] = 0.0 + cache_dic['fresh_threshold'] = 1 + cache_dic['force_fresh'] = 'global' + + mode = 'Taylor' + + if mode == 'original': + cache_dic['cache'] = cache + cache_dic['force_fresh'] = 'global' + cache_dic['max_order'] = 0 + cache_dic['first_enhance'] = 3 + + elif mode == 'ToCa': + cache_dic['cache_type'] = 'attention' + cache_dic['cache'] = cache + cache_dic['fresh_ratio_schedule'] = 'ToCa' + cache_dic['fresh_ratio'] = 0.1 + cache_dic['fresh_threshold'] = 5 + cache_dic['force_fresh'] = 'global' + cache_dic['soft_fresh_weight'] = 0.0 + cache_dic['max_order'] = 0 + cache_dic['first_enhance'] = 3 + + elif mode == 'Taylor': + cache_dic['cache'] = cache + cache_dic['fresh_threshold'] = 5 + cache_dic['taylor_cache'] = True + cache_dic['max_order'] = 1 + cache_dic['first_enhance'] = 1 + + elif mode == 'Delta': + cache_dic['cache'] = cache + cache_dic['fresh_ratio'] = 0.0 + cache_dic['fresh_threshold'] = 3 + cache_dic['Delta-DiT'] = True + cache_dic['max_order'] = 0 + cache_dic['first_enhance'] = 1 + + current = {} + current['activated_steps'] = [0] + current['step'] = 0 + current['num_steps'] = num_steps + + return cache_dic, current diff --git a/wan/taylorseer/cache_functions/cal_type.py b/wan/taylorseer/cache_functions/cal_type.py new file mode 100644 index 0000000..93391f6 --- /dev/null +++ b/wan/taylorseer/cache_functions/cal_type.py @@ -0,0 +1,42 @@ +from .force_scheduler import force_scheduler + +def cal_type(cache_dic, current): + ''' + Determine calculation type for this step + ''' + if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']): + # FORA:Uniform + first_step = (current['step'] == 0) + else: + # ToCa: First enhanced + first_step = (current['step'] < cache_dic['first_enhance']) + #first_step = (current['step'] <= 3) + + force_fresh = cache_dic['force_fresh'] + if not first_step: + fresh_interval = cache_dic['cal_threshold'] + else: + fresh_interval = cache_dic['fresh_threshold'] + + if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ): + current['type'] = 'full' + cache_dic['cache_counter'] = 0 + current['activated_steps'].append(current['step']) + #current['activated_times'].append(current['t']) + force_scheduler(cache_dic, current) + + elif (cache_dic['taylor_cache']): + cache_dic['cache_counter'] += 1 + current['type'] = 'Taylor' + + + elif (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive + cache_dic['cache_counter'] += 1 + current['type'] = 'ToCa' + # 'cache_noise' 'ToCa' 'FORA' + elif cache_dic['Delta-DiT']: + cache_dic['cache_counter'] += 1 + current['type'] = 'Delta-Cache' + else: + cache_dic['cache_counter'] += 1 + current['type'] = 'ToCa' diff --git a/wan/taylorseer/cache_functions/force_scheduler.py b/wan/taylorseer/cache_functions/force_scheduler.py new file mode 100644 index 0000000..4c54f05 --- /dev/null +++ b/wan/taylorseer/cache_functions/force_scheduler.py @@ -0,0 +1,16 @@ +import torch +def force_scheduler(cache_dic, current): + if cache_dic['fresh_ratio'] == 0: + # FORA + linear_step_weight = 0.0 + else: + # TokenCache + linear_step_weight = 0.0 + 
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps']) + threshold = torch.round(cache_dic['fresh_threshold'] / step_factor) + + # no force constrain for sensitive steps, cause the performance is good enough. + # you may have a try. + + cache_dic['cal_threshold'] = threshold + #return threshold \ No newline at end of file diff --git a/wan/taylorseer/forwards/__init__.py b/wan/taylorseer/forwards/__init__.py new file mode 100644 index 0000000..9e918e0 --- /dev/null +++ b/wan/taylorseer/forwards/__init__.py @@ -0,0 +1,5 @@ +from .wan_forward import wan_forward +from .xfusers_wan_forward import xfusers_wan_forward +from .wan_attention_forward import wan_attention_forward +from .wan_attention_forward_cache_step import wan_attention_forward_cache_step +from .wan_cache_forward import wan_cache_forward \ No newline at end of file diff --git a/wan/taylorseer/forwards/wan_attention_cache_forward.py b/wan/taylorseer/forwards/wan_attention_cache_forward.py new file mode 100644 index 0000000..71efb87 --- /dev/null +++ b/wan/taylorseer/forwards/wan_attention_cache_forward.py @@ -0,0 +1,23 @@ +import torch +import torch.cuda.amp as amp +from typing import Dict +from wan.taylorseer.taylorseer_utils import taylor_cache_init, derivative_approximation, taylor_formula + +@torch.compile +def wan_attention_cache_forward(sa_dict:Dict, ca_dict:Dict, ffn_dict:Dict, e:tuple, x:torch.Tensor, distance:int): + + seer_sa = taylor_formula(derivative_dict=sa_dict, distance=distance) + seer_ca = taylor_formula(derivative_dict=ca_dict, distance=distance) + seer_ffn = taylor_formula(derivative_dict=ffn_dict, distance=distance) + + x = cache_add(x, seer_sa, seer_ca, seer_ffn, e) + + return x + +def cache_add(x, sa, ca, ffn, e): + with amp.autocast(dtype=torch.float32): + x = x + sa * e[2] + x = x + ca + with amp.autocast(dtype=torch.float32): + x = x + ffn * e[5] + return x \ No newline at end of file diff --git a/wan/taylorseer/forwards/wan_attention_forward.py b/wan/taylorseer/forwards/wan_attention_forward.py new file mode 100644 index 0000000..344f129 --- /dev/null +++ b/wan/taylorseer/forwards/wan_attention_forward.py @@ -0,0 +1,82 @@ +import math + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from wan.modules import WanModel +from wan.modules.model import sinusoidal_embedding_1d, WanAttentionBlock +from wan.modules.attention import flash_attention + +from wan.taylorseer.taylorseer_utils import taylor_cache_init, derivative_approximation, taylor_formula +from .wan_attention_cache_forward import wan_attention_cache_forward + +def wan_attention_forward( + self:WanAttentionBlock, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + cache_dic, + current +): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + if current['type'] == 'full': + + # self-attention + current['module'] = 'self-attention' + taylor_cache_init(cache_dic=cache_dic, current=current) + y = self.self_attn( + self.norm1(x).float() * (1 + 
e[1]) + e[0], seq_lens, grid_sizes, + freqs) + derivative_approximation(cache_dic=cache_dic, current=current, feature=y) + with amp.autocast(dtype=torch.float32): + x = x + y * e[2] + + # cross-attention + current['module'] = 'cross-attention' + taylor_cache_init(cache_dic=cache_dic, current=current) + y = self.cross_attn(self.norm3(x), context, context_lens) + derivative_approximation(cache_dic=cache_dic, current=current, feature=y) + x = x + y + + # ffn + current['module'] = 'ffn' + taylor_cache_init(cache_dic=cache_dic, current=current) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + derivative_approximation(cache_dic=cache_dic, current=current, feature=y) + with amp.autocast(dtype=torch.float32): + x = x + y * e[5] + + elif current['type'] == 'Taylor': + + #x = wan_attention_cache_forward(cache_dic, current, e, x) + x = wan_attention_cache_forward( + sa_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['self-attention'], + ca_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['cross-attention'], + ffn_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['ffn'], + e=e, + x=x, + distance= current['step'] - current['activated_steps'][-1] + ) + + else: + raise ValueError(f"Not supported type: {current['type']}") + + return x diff --git a/wan/taylorseer/forwards/wan_attention_forward_cache_step.py b/wan/taylorseer/forwards/wan_attention_forward_cache_step.py new file mode 100644 index 0000000..d25936b --- /dev/null +++ b/wan/taylorseer/forwards/wan_attention_forward_cache_step.py @@ -0,0 +1,37 @@ +import torch +import torch.cuda.amp as amp +from diffusers.configuration_utils import ConfigMixin, register_to_config +from wan.modules.model import WanAttentionBlock + +from .wan_attention_cache_forward import wan_attention_cache_forward + +def wan_attention_forward_cache_step( + self:WanAttentionBlock, + x, + e, + layer_cache_dict, + distance, +): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + x = wan_attention_cache_forward( + sa_dict= layer_cache_dict['self-attention'], + ca_dict= layer_cache_dict['cross-attention'], + ffn_dict= layer_cache_dict['ffn'], + e=e, + x=x, + distance= distance + ) + + return x diff --git a/wan/taylorseer/forwards/wan_cache_forward.py b/wan/taylorseer/forwards/wan_cache_forward.py new file mode 100644 index 0000000..1ff6ef6 --- /dev/null +++ b/wan/taylorseer/forwards/wan_cache_forward.py @@ -0,0 +1,18 @@ +import torch +import torch.cuda.amp as amp +from wan.modules import WanModel + +@torch.compile +def wan_cache_forward(self:WanModel, + e:torch.Tensor, + cond_cache_dict:dict, + distance:int, + x:torch.Tensor) -> torch.Tensor: + + for i, block in enumerate(self.blocks): + x = block.cache_step_forward(x, + e=e, + layer_cache_dict=cond_cache_dict[i], + distance=distance) + + return x \ No newline at end of file diff --git a/wan/taylorseer/forwards/wan_forward.py b/wan/taylorseer/forwards/wan_forward.py new file mode 100644 index 0000000..d104735 --- /dev/null +++ b/wan/taylorseer/forwards/wan_forward.py @@ -0,0 +1,114 @@ +import math + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from 
diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from wan.modules import WanModel +from wan.modules.model import sinusoidal_embedding_1d +from wan.taylorseer.cache_functions import cal_type +from wan.taylorseer.taylorseer_utils import taylor_formula +from .wan_cache_forward import wan_cache_forward + +def wan_forward( + self:WanModel, + x, + t, + context, + seq_len, + current_step, + current_stream, + clip_fea=None, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + + self.current['step'] = current_step + self.current['stream'] = current_stream + if current_stream == 'cond_stream': + cal_type(self.cache_dic, self.current) + + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + cache_dic=self.cache_dic, + current=self.current) + + for i, block in enumerate(self.blocks): + self.current['layer'] = i + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] \ No newline at end of file diff --git a/wan/taylorseer/forwards/xfusers_wan_forward.py b/wan/taylorseer/forwards/xfusers_wan_forward.py new file mode 100644 index 0000000..0229cde --- /dev/null +++ b/wan/taylorseer/forwards/xfusers_wan_forward.py @@ -0,0 +1,104 @@ +import torch +import torch.cuda.amp as amp +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from wan.modules import WanModel +from wan.modules.model import sinusoidal_embedding_1d 
+from wan.taylorseer.cache_functions import cal_type +from wan.taylorseer.taylorseer_utils import taylor_formula +from .wan_cache_forward import wan_cache_forward + +def xfusers_wan_forward( + self:WanModel, + x, + t, + context, + seq_len, + current_step, + current_stream, + clip_fea=None, + y=None, +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + + self.current['step'] = current_step + self.current['stream'] = current_stream + + if current_stream == 'cond_stream': + cal_type(self.cache_dic, self.current) + + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + cache_dic=self.cache_dic, + current=self.current) + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for i, block in enumerate(self.blocks): + self.current['layer'] = i + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] diff --git a/wan/taylorseer/generates/__init__.py b/wan/taylorseer/generates/__init__.py new file mode 100644 index 0000000..46c6fd0 --- /dev/null +++ b/wan/taylorseer/generates/__init__.py @@ -0,0 +1,2 @@ +from .wan_t2v_generate import wan_t2v_generate +from .wan_i2v_generate import wan_i2v_generate \ No newline at end of file diff --git a/wan/taylorseer/generates/wan_i2v_generate.py b/wan/taylorseer/generates/wan_i2v_generate.py new file mode 100644 index 0000000..0507d33 --- /dev/null +++ b/wan/taylorseer/generates/wan_i2v_generate.py @@ -0,0 +1,264 @@ +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from wan.distributed.fsdp import shard_model +from wan.modules.clip import CLIPModel +from wan.modules.model import WanModel +from wan.modules.t5 
import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from wan import WanI2V + +def wan_i2v_generate(self:WanI2V, + input_prompt, + img, + max_area=720 * 1280, + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from an input image and a text prompt using a diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image; converted internally to a tensor of shape [3, H, W] + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults to 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, a random seed is used + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + + F = frame_num + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( + self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + 16, + 21, + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=self.device) + + msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], + dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + self.clip.model.to(self.device) + clip_context = self.clip.visual([img[:, None, :, :]]) + if offload_model: + self.clip.model.cpu() + + y = self.vae.encode([ + torch.concat([ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode='bicubic').transpose( + 0, 1), + torch.zeros(3, 80, h, w) + ], + dim=1).to(self.device) + ])[0] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + + arg_c = { + 'context': [context[0]], + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + } + + arg_null = { + 'context': context_null, + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + } + + if offload_model: + torch.cuda.empty_cache() + + self.model.to(self.device) + + for i, t in enumerate(tqdm(timesteps)): + 
+ current_step = i + + latent_model_input = [latent.to(self.device)] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + current_stream = 'cond_stream' + + noise_pred_cond = self.model( + latent_model_input, t=timestep, + current_step = current_step, + current_stream = current_stream, + **arg_c)[0].to( + torch.device('cpu') if offload_model else self.device) + + if offload_model: + torch.cuda.empty_cache() + + current_stream = 'uncond_stream' + + noise_pred_uncond = self.model( + latent_model_input, t=timestep, + current_step = current_step, + current_stream = current_stream, + **arg_null)[0].to( + torch.device('cpu') if offload_model else self.device) + + if offload_model: + torch.cuda.empty_cache() + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + latent = latent.to( + torch.device('cpu') if offload_model else self.device) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) + + x0 = [latent.to(self.device)] + del latent_model_input, timestep + + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latent + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/wan/taylorseer/generates/wan_t2v_generate.py b/wan/taylorseer/generates/wan_t2v_generate.py new file mode 100644 index 0000000..784d276 --- /dev/null +++ b/wan/taylorseer/generates/wan_t2v_generate.py @@ -0,0 +1,207 @@ +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + +from wan.distributed.fsdp import shard_model +from wan.modules.model import WanModel +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae import WanVAE +from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +from wan import WanT2V +from wan.taylorseer.cache_functions import cache_init, cal_type + +def wan_t2v_generate(self:WanT2V, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. 
If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + F = frame_num + + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + + @contextmanager + + def noop_no_sync(): + yield + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + + timesteps = sample_scheduler.timesteps + + elif sample_solver == 'dpm++': + + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + self.model.to(self.device) + + for i, t in enumerate(tqdm(timesteps)): + torch.compiler.cudagraph_mark_step_begin() + current_step = i + + latent_model_input = latents + timestep = [t] + timestep = torch.stack(timestep) + + current_stream = 'cond_stream' + + noise_pred_cond = self.model( + latent_model_input, t=timestep, + current_step = current_step, + current_stream = current_stream, + **arg_c)[0] + + current_stream = 'uncond_stream' + + noise_pred_uncond = self.model( + latent_model_input, t=timestep, + current_step = current_step, + current_stream = current_stream, + **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, 
+ generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latents + del sample_scheduler + + if offload_model: + gc.collect() + torch.cuda.synchronize() + + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None \ No newline at end of file diff --git a/wan/taylorseer/taylorseer_utils/__init__.py b/wan/taylorseer/taylorseer_utils/__init__.py new file mode 100644 index 0000000..0d1b212 --- /dev/null +++ b/wan/taylorseer/taylorseer_utils/__init__.py @@ -0,0 +1,47 @@ +from typing import Dict +import torch +import math + +def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor): + """ + Compute derivative approximation. + + :param cache_dic: Cache dictionary + :param current: Information of the current step + """ + difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2] + #difference_distance = current['activated_times'][-1] - current['activated_times'][-2] + + updated_taylor_factors = {} + updated_taylor_factors[0] = feature + + for i in range(cache_dic['max_order']): + if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2): + updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance + else: + break + + cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors + +def taylor_formula(derivative_dict: Dict, distance: int) -> torch.Tensor: + """ + Compute Taylor expansion error. + + :param derivative_dict: Derivative dictionary + :param x: Current step + """ + output=0 + for i in range(len(derivative_dict)): + output += (1 / math.factorial(i)) * derivative_dict[i] * (distance ** i) + + return output + +def taylor_cache_init(cache_dic: Dict, current: Dict): + """ + Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache. + + :param cache_dic: Cache dictionary + :param current: Information of the current step + """ + if (current['step'] == 0) and (cache_dic['taylor_cache']): + cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {} diff --git a/wan/text2video.py b/wan/text2video.py index c518b61..c0e0dcd 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -38,6 +38,7 @@ class WanT2V: dit_fsdp=False, use_usp=False, t5_cpu=False, + use_taylor_cache = False ): r""" Initializes the Wan text-to-video generation model components. 
@@ -86,21 +87,49 @@ class WanT2V: logging.info(f"Creating WanModel from {checkpoint_dir}") self.model = WanModel.from_pretrained(checkpoint_dir) self.model.eval().requires_grad_(False) - + + from wan.taylorseer.forwards import wan_attention_forward, xfusers_wan_forward, wan_forward #, wan_attention_forward_cache_step if use_usp: - from xfuser.core.distributed import get_sequence_parallel_world_size + from xfuser.core.distributed import \ + get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import ( - usp_attn_forward, - usp_dit_forward, - ) + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward) + for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) - self.model.forward = types.MethodType(usp_dit_forward, self.model) + + if use_taylor_cache: + block.forward = types.MethodType(wan_attention_forward, block) + #block.cache_step_forward = types.MethodType(wan_attention_forward_cache_step, block) + else: + self.model.forward = types.MethodType(usp_dit_forward, self.model) + if use_taylor_cache: + self.model.forward = types.MethodType(xfusers_wan_forward, self.model) self.sp_size = get_sequence_parallel_world_size() else: + if use_taylor_cache: + for block in self.model.blocks: + block.forward = types.MethodType(wan_attention_forward, block) + #block.cache_step_forward = types.MethodType(wan_attention_forward_cache_step, block) + + self.model.forward = types.MethodType(wan_forward, self.model) self.sp_size = 1 + + #if use_usp: + # from xfuser.core.distributed import \ + # get_sequence_parallel_world_size +# + # from .distributed.xdit_context_parallel import (usp_attn_forward, + # usp_dit_forward) + # for block in self.model.blocks: + # block.self_attn.forward = types.MethodType( + # usp_attn_forward, block.self_attn) + # self.model.forward = types.MethodType(usp_dit_forward, self.model) + # self.sp_size = get_sequence_parallel_world_size() + #else: + # self.sp_size = 1 if dist.is_initialized(): dist.barrier() From e198d833eb31757bfcb2c4784eccfe87d102f521 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 10 Jun 2025 06:15:54 +0000 Subject: [PATCH 2/2] enable TeaCache and FSDP to be used together --- generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/generate.py b/generate.py index 739de46..c567440 100644 --- a/generate.py +++ b/generate.py @@ -949,7 +949,8 @@ def generate(args): if args.enable_teacache: wan_t2v.__class__.generate = t2v_generate wan_t2v.model.__class__.enable_teacache = True - wan_t2v.model.__class__.forward = teacache_forward + if args.ulysses_size * args.ring_size == 1: # avoid conflicting with FSDP + wan_t2v.model.__class__.forward = teacache_forward wan_t2v.model.__class__.cnt = 0 wan_t2v.model.__class__.num_steps = args.sample_steps*2 wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh
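
For reference, the caching scheme introduced by PATCH 1/2 (`derivative_approximation` and `taylor_formula` in wan/taylorseer/taylorseer_utils/__init__.py) amounts to estimating per-block output derivatives by finite differences on the fully computed steps and extrapolating block outputs on cached steps with a Taylor series in the step distance. The snippet below is a minimal standalone sketch of that idea, not part of the patch; the helper names are made up for illustration.

import math
import torch

def update_derivatives(cache, feature, step_gap, max_order=1):
    # On a fully computed step: store the fresh feature (order 0) and update
    # higher-order terms by finite differences against the previous cache.
    new = {0: feature}
    for i in range(max_order):
        if i in cache:
            new[i + 1] = (new[i] - cache[i]) / step_gap
        else:
            break
    cache.clear()
    cache.update(new)

def taylor_extrapolate(cache, distance):
    # On a cached step: predict the feature `distance` steps after the last
    # full computation, f(t + d) ~= sum_i f^(i)(t) / i! * d**i.
    return sum(cache[i] / math.factorial(i) * distance ** i for i in cache)

# Toy usage: a block output that drifts linearly across diffusion steps.
cache = {}
update_derivatives(cache, torch.zeros(4), step_gap=1)         # step 0: order 0 only
update_derivatives(cache, torch.full((4,), 0.5), step_gap=5)  # step 5: adds a first-derivative term
print(taylor_extrapolate(cache, distance=2))                  # ~0.7, the prediction for step 7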