Merge pull request #7 from bytedance-iaas/taylor_cache

[FEAT]: enable taylor cache
zc8gerard 2025-06-10 14:23:44 +08:00 committed by GitHub
commit 6d85885134
20 changed files with 1688 additions and 9 deletions

@@ -949,7 +949,8 @@ def generate(args):
if args.enable_teacache:
wan_t2v.__class__.generate = t2v_generate
wan_t2v.model.__class__.enable_teacache = True
wan_t2v.model.__class__.forward = teacache_forward
if args.ulysses_size * args.ring_size == 1: # avoid conflicting with fsdp
wan_t2v.model.__class__.forward = teacache_forward
wan_t2v.model.__class__.cnt = 0
wan_t2v.model.__class__.num_steps = args.sample_steps*2
wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh

taylor_generator.py (new file)
@@ -0,0 +1,422 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
warnings.filterwarnings('ignore')
import torch, random
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
from wan.taylorseer.generates import wan_t2v_generate, wan_i2v_generate
from wan.taylorseer.forwards import wan_forward, xfusers_wan_forward, wan_attention_forward
import types
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
},
"i2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
# The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
# T2I frame_num check
if "t2i" in args.task:
assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
assert args.size in SUPPORTED_SIZES[args.task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames to sample from a image or video. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5.")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU.")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the image or video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="zh",
choices=["zh", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=5.0,
help="Classifier free guidance scale.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = local_rank
_init_logging(rank)
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=world_size)
else:
assert not (
args.t5_fsdp or args.dit_fsdp
), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
assert not (
args.ulysses_size > 1 or args.ring_size > 1
), "context parallelism is not supported in non-distributed environments."
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=args.ring_size,
ulysses_degree=args.ulysses_size,
)
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task,
device=rank)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if dist.is_initialized():
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
use_taylor_cache=True,
)
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
# TaylorSeer
wan_t2v.generate = types.MethodType(wan_t2v_generate, wan_t2v)
#wan_t2v = torch.compile(wan_t2v, mode="max-autotune")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
# TaylorSeer
wan_i2v.generate = types.MethodType(wan_i2v_generate, wan_i2v)
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)
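
For reference, a minimal programmatic driver for the script above is sketched below. The checkpoint path, prompt, and size are placeholders, and the sketch assumes the file is saved as taylor_generator.py, is importable, and runs on a single GPU (so no distributed launch is needed).

# Hypothetical usage sketch; ckpt_dir, prompt and size are placeholders.
import sys
import taylor_generator as tg

sys.argv = [
    "taylor_generator.py",
    "--task", "t2v-1.3B",
    "--size", "832*480",                # must be listed in SUPPORTED_SIZES for the task
    "--ckpt_dir", "./Wan2.1-T2V-1.3B",  # placeholder checkpoint directory
    "--prompt", "Two cats boxing on a spotlighted stage.",
]
args = tg._parse_args()  # validates task/size and fills defaults (50 sampling steps for t2v)
tg.generate(args)        # builds WanT2V with use_taylor_cache=True and renders the video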

@@ -0,0 +1,192 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
def usp_dit_forward(
self,
x,
t,
context,
seq_len,
clip_fea=None,
y=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
# Context Parallel
x = torch.chunk(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def usp_attn_forward(self,
x,
seq_lens,
grid_sizes,
freqs,
dtype=torch.bfloat16):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# TODO: We should use unpadded q, k, v for attention.
# k_lens = seq_lens // get_sequence_parallel_world_size()
# if k_lens is not None:
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
x = xFuserLongContextAttention()(
None,
query=half(q),
key=half(k),
value=half(v),
window_size=self.window_size)
# TODO: padding after attention.
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
# output
x = x.flatten(2)
x = self.o(x)
return x

@@ -8,6 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
from wan.taylorseer.cache_functions import cache_init
__all__ = ['WanModel']
@ -506,8 +507,11 @@ class WanModel(ModelMixin, ConfigMixin):
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
# initialize weights
self.init_weights()
self.cache_init()
def cache_init(self):
self.cache_dic, self.current = cache_init(self)
def forward(
self,
x,

@@ -0,0 +1,3 @@
from .cache_init import cache_init
from .cal_type import cal_type
from .force_scheduler import force_scheduler

@@ -0,0 +1,67 @@
#from wan.modules import WanModel
def cache_init(self, num_steps= 50):
'''
Initialization for cache.
'''
cache_dic = {}
cache = {}
cache[-1]={}
cache[-1]['cond_stream']={}
cache[-1]['uncond_stream']={}
cache_dic['cache_counter'] = 0
for j in range(self.num_layers):
cache[-1]['cond_stream'][j] = {}
cache[-1]['uncond_stream'][j] = {}
cache_dic['taylor_cache'] = False
cache_dic['Delta-DiT'] = False
cache_dic['cache_type'] = 'random'
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 1
cache_dic['force_fresh'] = 'global'
mode = 'Taylor'
if mode == 'original':
cache_dic['cache'] = cache
cache_dic['force_fresh'] = 'global'
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 3
elif mode == 'ToCa':
cache_dic['cache_type'] = 'attention'
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.1
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 3
elif mode == 'Taylor':
cache_dic['cache'] = cache
cache_dic['fresh_threshold'] = 5
cache_dic['taylor_cache'] = True
cache_dic['max_order'] = 1
cache_dic['first_enhance'] = 1
elif mode == 'Delta':
cache_dic['cache'] = cache
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 3
cache_dic['Delta-DiT'] = True
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
current = {}
current['activated_steps'] = [0]
current['step'] = 0
current['num_steps'] = num_steps
return cache_dic, current
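
As a quick sanity check, the sketch below inspects the default 'Taylor' configuration returned by cache_init, using a hypothetical toy object with num_layers=2 in place of a real WanModel and assuming the package layout shown in this diff:

from types import SimpleNamespace
from wan.taylorseer.cache_functions import cache_init

toy = SimpleNamespace(num_layers=2)  # stand-in for a WanModel
cache_dic, current = cache_init(toy, num_steps=50)

print(cache_dic['taylor_cache'])     # True  -> Taylor caching is active
print(cache_dic['fresh_threshold'])  # 5     -> a full recompute every 5 steps
print(cache_dic['max_order'])        # 1     -> cache the feature plus its first derivative
print(cache_dic['first_enhance'])    # 1     -> only step 0 is forced to be a full step
print(current)                       # {'activated_steps': [0], 'step': 0, 'num_steps': 50}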

@@ -0,0 +1,42 @@
from .force_scheduler import force_scheduler
def cal_type(cache_dic, current):
'''
Determine calculation type for this step
'''
if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
# FORA:Uniform
first_step = (current['step'] == 0)
else:
# ToCa: First enhanced
first_step = (current['step'] < cache_dic['first_enhance'])
#first_step = (current['step'] <= 3)
force_fresh = cache_dic['force_fresh']
if not first_step:
fresh_interval = cache_dic['cal_threshold']
else:
fresh_interval = cache_dic['fresh_threshold']
if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ):
current['type'] = 'full'
cache_dic['cache_counter'] = 0
current['activated_steps'].append(current['step'])
#current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif (cache_dic['taylor_cache']):
cache_dic['cache_counter'] += 1
current['type'] = 'Taylor'
elif (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
cache_dic['cache_counter'] += 1
current['type'] = 'ToCa'
# 'cache_noise' 'ToCa' 'FORA'
elif cache_dic['Delta-DiT']:
cache_dic['cache_counter'] += 1
current['type'] = 'Delta-Cache'
else:
cache_dic['cache_counter'] += 1
current['type'] = 'ToCa'
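
To make the refresh pattern concrete, the sketch below traces the computation type cal_type assigns at each of twelve steps under the default Taylor configuration (toy model, same assumptions as the sketch above):

from types import SimpleNamespace
from wan.taylorseer.cache_functions import cache_init, cal_type

cache_dic, current = cache_init(SimpleNamespace(num_layers=1), num_steps=12)
schedule = []
for step in range(12):
    current['step'] = step
    cal_type(cache_dic, current)
    schedule.append(current['type'])

print(schedule)
# ['full', 'Taylor', 'Taylor', 'Taylor', 'Taylor',
#  'full', 'Taylor', 'Taylor', 'Taylor', 'Taylor',
#  'full', 'Taylor']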

@@ -0,0 +1,16 @@
import torch
def force_scheduler(cache_dic, current):
if cache_dic['fresh_ratio'] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
# No forced constraint for sensitive steps, because the performance is good enough.
# You may give it a try.
cache_dic['cal_threshold'] = threshold
#return threshold
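
Since both branches currently set linear_step_weight to 0.0, step_factor is always 1 and cal_threshold simply equals fresh_threshold. The sketch below shows what a hypothetical nonzero weight would do; the value 0.4 is illustrative only and is not used anywhere in this PR:

import torch

fresh_threshold, num_steps, w = 5, 50, 0.4  # w is a hypothetical linear_step_weight
for step in (0, 25, 49):
    step_factor = torch.tensor(1 - w + 2 * w * step / num_steps)
    print(step, torch.round(fresh_threshold / step_factor).item())
# 0 -> 8.0, 25 -> 5.0, 49 -> 4.0  (full refreshes become more frequent late in sampling)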

@@ -0,0 +1,5 @@
from .wan_forward import wan_forward
from .xfusers_wan_forward import xfusers_wan_forward
from .wan_attention_forward import wan_attention_forward
from .wan_attention_forward_cache_step import wan_attention_forward_cache_step
from .wan_cache_forward import wan_cache_forward

@@ -0,0 +1,23 @@
import torch
import torch.cuda.amp as amp
from typing import Dict
from wan.taylorseer.taylorseer_utils import taylor_cache_init, derivative_approximation, taylor_formula
@torch.compile
def wan_attention_cache_forward(sa_dict:Dict, ca_dict:Dict, ffn_dict:Dict, e:tuple, x:torch.Tensor, distance:int):
seer_sa = taylor_formula(derivative_dict=sa_dict, distance=distance)
seer_ca = taylor_formula(derivative_dict=ca_dict, distance=distance)
seer_ffn = taylor_formula(derivative_dict=ffn_dict, distance=distance)
x = cache_add(x, seer_sa, seer_ca, seer_ffn, e)
return x
def cache_add(x, sa, ca, ffn, e):
with amp.autocast(dtype=torch.float32):
x = x + sa * e[2]
x = x + ca
with amp.autocast(dtype=torch.float32):
x = x + ffn * e[5]
return x

@@ -0,0 +1,82 @@
import math
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from wan.modules import WanModel
from wan.modules.model import sinusoidal_embedding_1d, WanAttentionBlock
from wan.modules.attention import flash_attention
from wan.taylorseer.taylorseer_utils import taylor_cache_init, derivative_approximation, taylor_formula
from .wan_attention_cache_forward import wan_attention_cache_forward
def wan_attention_forward(
self:WanAttentionBlock,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
cache_dic,
current
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
if current['type'] == 'full':
# self-attention
current['module'] = 'self-attention'
taylor_cache_init(cache_dic=cache_dic, current=current)
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs)
derivative_approximation(cache_dic=cache_dic, current=current, feature=y)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention
current['module'] = 'cross-attention'
taylor_cache_init(cache_dic=cache_dic, current=current)
y = self.cross_attn(self.norm3(x), context, context_lens)
derivative_approximation(cache_dic=cache_dic, current=current, feature=y)
x = x + y
# ffn
current['module'] = 'ffn'
taylor_cache_init(cache_dic=cache_dic, current=current)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
derivative_approximation(cache_dic=cache_dic, current=current, feature=y)
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
elif current['type'] == 'Taylor':
#x = wan_attention_cache_forward(cache_dic, current, e, x)
x = wan_attention_cache_forward(
sa_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['self-attention'],
ca_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['cross-attention'],
ffn_dict=cache_dic['cache'][-1][current['stream']][current['layer']]['ffn'],
e=e,
x=x,
distance= current['step'] - current['activated_steps'][-1]
)
else:
raise ValueError(f"Not supported type: {current['type']}")
return x
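
For readability, the nested indexing in the 'Taylor' branch above follows the cache layout built by cache_init and derivative_approximation; a sketch of the structure:

# cache_dic['cache'][-1][stream][layer][module][order] -> torch.Tensor
#   stream: 'cond_stream' or 'uncond_stream' (the two classifier-free guidance passes)
#   layer:  transformer block index (int)
#   module: 'self-attention', 'cross-attention' or 'ffn'
#   order:  0 holds the cached feature, 1 its finite-difference derivative (max_order = 1)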

@@ -0,0 +1,37 @@
import torch
import torch.cuda.amp as amp
from diffusers.configuration_utils import ConfigMixin, register_to_config
from wan.modules.model import WanAttentionBlock
from .wan_attention_cache_forward import wan_attention_cache_forward
def wan_attention_forward_cache_step(
self:WanAttentionBlock,
x,
e,
layer_cache_dict,
distance,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
x = wan_attention_cache_forward(
sa_dict= layer_cache_dict['self-attention'],
ca_dict= layer_cache_dict['cross-attention'],
ffn_dict= layer_cache_dict['ffn'],
e=e,
x=x,
distance= distance
)
return x

@@ -0,0 +1,18 @@
import torch
import torch.cuda.amp as amp
from wan.modules import WanModel
@torch.compile
def wan_cache_forward(self:WanModel,
e:torch.Tensor,
cond_cache_dict:dict,
distance:int,
x:torch.Tensor) -> torch.Tensor:
for i, block in enumerate(self.blocks):
x = block.cache_step_forward(x,
e=e,
layer_cache_dict=cond_cache_dict[i],
distance=distance)
return x

@@ -0,0 +1,114 @@
import math
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from wan.modules import WanModel
from wan.modules.model import sinusoidal_embedding_1d
from wan.taylorseer.cache_functions import cal_type
from wan.taylorseer.taylorseer_utils import taylor_formula
from .wan_cache_forward import wan_cache_forward
def wan_forward(
self:WanModel,
x,
t,
context,
seq_len,
current_step,
current_stream,
clip_fea=None,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
self.current['step'] = current_step
self.current['stream'] = current_stream
if current_stream == 'cond_stream':
cal_type(self.cache_dic, self.current)
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
cache_dic=self.cache_dic,
current=self.current)
for i, block in enumerate(self.blocks):
self.current['layer'] = i
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]

@@ -0,0 +1,104 @@
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from wan.modules import WanModel
from wan.modules.model import sinusoidal_embedding_1d
from wan.taylorseer.cache_functions import cal_type
from wan.taylorseer.taylorseer_utils import taylor_formula
from .wan_cache_forward import wan_cache_forward
def xfusers_wan_forward(
self:WanModel,
x,
t,
context,
seq_len,
current_step,
current_stream,
clip_fea=None,
y=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
self.current['step'] = current_step
self.current['stream'] = current_stream
if current_stream == 'cond_stream':
cal_type(self.cache_dic, self.current)
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
cache_dic=self.cache_dic,
current=self.current)
# Context Parallel
x = torch.chunk(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
for i, block in enumerate(self.blocks):
self.current['layer'] = i
x = block(x, **kwargs)
# head
x = self.head(x, e)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]

@@ -0,0 +1,2 @@
from .wan_t2v_generate import wan_t2v_generate
from .wan_i2v_generate import wan_i2v_generate

@@ -0,0 +1,264 @@
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from wan.distributed.fsdp import shard_model
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan import WanI2V
def wan_i2v_generate(self:WanI2V,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image tensor. Shape: [3, H, W]
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
21,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, 80, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for i, t in enumerate(tqdm(timesteps)):
current_step = i
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
current_stream = 'cond_stream'
noise_pred_cond = self.model(
latent_model_input, t=timestep,
current_step = current_step,
current_stream = current_stream,
**arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
current_stream = 'uncond_stream'
noise_pred_uncond = self.model(
latent_model_input, t=timestep,
current_step = current_step,
current_stream = current_stream,
**arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None

@@ -0,0 +1,207 @@
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from wan.distributed.fsdp import shard_model
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan import WanT2V
from wan.taylorseer.cache_functions import cache_init, cal_type
def wan_t2v_generate(self:WanT2V,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width, height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width (from size)
"""
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
self.model.to(self.device)
for i, t in enumerate(tqdm(timesteps)):
torch.compiler.cudagraph_mark_step_begin()
current_step = i
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
current_stream = 'cond_stream'
noise_pred_cond = self.model(
latent_model_input, t=timestep,
current_step = current_step,
current_stream = current_stream,
**arg_c)[0]
current_stream = 'uncond_stream'
noise_pred_uncond = self.model(
latent_model_input, t=timestep,
current_step = current_step,
current_stream = current_stream,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
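
Both generators call the model twice per timestep, once per stream, because classifier-free guidance combines a conditional and an unconditional prediction; the separate 'cond_stream' and 'uncond_stream' cache entries give each prediction its own Taylor history. In symbols, with s the guide_scale:

\hat{\epsilon}_t = \epsilon_\theta(x_t, \varnothing) + s \, \bigl( \epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing) \bigr)

which matches noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) above.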

@@ -0,0 +1,47 @@
from typing import Dict
import torch
import math
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
"""
Compute the finite-difference derivative approximation and refresh the cached Taylor factors.
:param cache_dic: Cache dictionary
:param current: Information of the current step
:param feature: Output of the current module at the current step
"""
difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]
#difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
updated_taylor_factors = {}
updated_taylor_factors[0] = feature
for i in range(cache_dic['max_order']):
if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2):
updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance
else:
break
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors
def taylor_formula(derivative_dict: Dict, distance: int) -> torch.Tensor:
"""
Evaluate the Taylor expansion at the given distance from the last activated step.
:param derivative_dict: Dictionary of cached derivatives (order -> tensor)
:param distance: Number of steps since the last fully computed step
"""
output = 0
for i in range(len(derivative_dict)):
output += (1 / math.factorial(i)) * derivative_dict[i] * (distance ** i)
return output
def taylor_cache_init(cache_dic: Dict, current: Dict):
"""
Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache.
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
if (current['step'] == 0) and (cache_dic['taylor_cache']):
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {}
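
These three helpers implement a finite-difference Taylor expansion of each module's output along the sampling trajectory. Writing F^(i)(t_k) for the order-i factor stored at the last fully computed ('activated') step t_k, and Δ for the distance argument, taylor_formula and derivative_approximation compute:

\mathcal{F}(t_k + \Delta) \approx \sum_{i=0}^{m} \frac{\mathcal{F}^{(i)}(t_k)}{i!} \, \Delta^{i},
\qquad
\mathcal{F}^{(i+1)}(t_k) \approx \frac{\mathcal{F}^{(i)}(t_k) - \mathcal{F}^{(i)}(t_{k-1})}{t_k - t_{k-1}},

where m is max_order (1 in the Taylor mode above) and t_{k-1} is the previous activated step.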

@@ -38,6 +38,7 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
use_taylor_cache=False,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -86,21 +87,49 @@ class WanT2V:
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
from wan.taylorseer.forwards import wan_attention_forward, xfusers_wan_forward, wan_forward#, wan_attention_forward_cache_step
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
if use_taylor_cache:
block.forward = types.MethodType(wan_attention_forward, block)
#block.cache_step_forward = types.MethodType(wan_attention_forward_cache_step, block)
else:
self.model.forward = types.MethodType(usp_dit_forward, self.model)
if use_taylor_cache:
self.model.forward = types.MethodType(xfusers_wan_forward, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
if use_taylor_cache:
for block in self.model.blocks:
block.forward = types.MethodType(wan_attention_forward, block)
#block.cache_step_forward = types.MethodType(wan_attention_forward_cache_step, block)
self.model.forward = types.MethodType(wan_forward, self.model)
self.sp_size = 1
#if use_usp:
# from xfuser.core.distributed import \
# get_sequence_parallel_world_size
#
# from .distributed.xdit_context_parallel import (usp_attn_forward,
# usp_dit_forward)
# for block in self.model.blocks:
# block.self_attn.forward = types.MethodType(
# usp_attn_forward, block.self_attn)
# self.model.forward = types.MethodType(usp_dit_forward, self.model)
# self.sp_size = get_sequence_parallel_world_size()
#else:
# self.sp_size = 1
if dist.is_initialized():
dist.barrier()
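
Because the +/- diff markers are stripped in this view, the patching logic above is easiest to read as a dispatch table; the summary below is a sketch of how the four cases resolve:

#   use_usp  use_taylor_cache   block.forward            model.forward
#   -------  ----------------   ---------------------    -------------------
#   False    False              (unchanged)              (unchanged)
#   False    True               wan_attention_forward    wan_forward
#   True     False              (unchanged)              usp_dit_forward
#   True     True               wan_attention_forward    xfusers_wan_forward
#
# In both use_usp cases, block.self_attn.forward is additionally replaced by usp_attn_forward.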