Added TeaCache support

This commit is contained in:
DeepBeepMeep 2025-03-03 18:41:33 +01:00
parent a15ac428c3
commit ec1159bb59
5 changed files with 104 additions and 34 deletions

View File

@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added TeaCache support for faster generations: an optimized version of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of TeaCache (https://github.com/ali-vilab/TeaCache); see the sketch after this list
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan models, including the Image to Video model
- Reduced memory consumption by a factor of 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
- The usual perks: web interface, multiple generations, Lora support, Sage attention, auto download of models, ...
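For the curious, here is a minimal sketch of the TeaCache idea as it is used in this commit. The names below (`TeaCacheState`, `should_run_blocks`) are illustrative only and are not the repo's API: on each denoising step the relative L1 change of the time-modulation input is accumulated, and while it stays below a threshold the transformer blocks are skipped and a cached residual is reused instead.

```python
import torch

def relative_l1(prev: torch.Tensor, cur: torch.Tensor) -> float:
    # mean absolute change, normalised by the mean magnitude of the previous tensor
    return (torch.abs(cur - prev).mean() / torch.abs(prev).mean()).item()

class TeaCacheState:
    """Decide, per denoising step, whether the transformer blocks can be skipped."""

    def __init__(self, threshold: float = 0.05):
        self.threshold = threshold      # higher -> more skipped steps, lower quality
        self.accumulated = 0.0          # accumulated relative change since the last full pass
        self.prev_modulated = None      # modulated time embedding from the previous step
        self.cached_residual = None     # (output - input) of the last full pass, stored by the caller

    def should_run_blocks(self, modulated: torch.Tensor) -> bool:
        if self.prev_modulated is None:
            run = True                  # first tracked step: always compute
        else:
            self.accumulated += relative_l1(self.prev_modulated, modulated)
            run = self.accumulated >= self.threshold
            if run:
                self.accumulated = 0.0  # reset once we pay for a full pass
        self.prev_modulated = modulated.clone()
        return run
```

When `should_run_blocks` returns False, the cached residual from the last full pass is simply added to the hidden states instead of running the blocks; the actual code keeps two such states, one for the conditional and one for the unconditional pass of classifier-free guidance.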
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
### Profiles (for power users only)
You can choose between 5 profiles, but only two are really relevant here:
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slighty faster, but less VRAM
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
- LowRAM_LowVRAM (4): loads only the parts of the model that are needed; low VRAM and low RAM requirements but slightly slower
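As a rough illustration of the trade-off between the two profiles (this is only a sketch, not the repo's actual loading code, which is handled by its offloading layer):

```python
import torch

def place_model(model: torch.nn.Module, profile: int, device: str = "cuda") -> torch.nn.Module:
    # Illustrative only: profile 3 keeps the whole model resident in VRAM (slightly faster),
    # while profile 4 keeps the weights in system RAM and moves sub-modules to the GPU on demand.
    if profile == 3:      # LowRAM_HighVRAM
        return model.to(device)
    if profile == 4:      # LowRAM_LowVRAM
        return model.to("cpu")
    raise ValueError(f"only profiles 3 and 4 are covered by this sketch, got {profile}")
```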

View File

@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback
import math
def _parse_args():
parser = argparse.ArgumentParser(
@@ -650,6 +650,7 @@ def generate_video(
embedded_guidance_scale,
repeat_generation,
tea_cache,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,
@@ -783,12 +784,15 @@ def generate_video(
break
if trans.enable_teacache:
trans.num_steps = num_inference_steps
trans.cnt = 0
trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
trans.accumulated_rel_l1_distance = 0
trans.previous_modulated_input = None
trans.previous_residual = None
trans.teacache_counter = 0
trans.rel_l1_thresh = tea_cache
trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
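# e.g. with tea_cache_start_step_perc=20 and num_inference_steps=30, caching starts at step max(ceil(20*30/100), 2) = 6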
trans.previous_residual_uncond = None
trans.previous_modulated_input_uncond = None
trans.previous_residual_cond = None
trans.previous_modulated_input_cond = None
trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"
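# profiles 1 and 3 presumably keep the model resident in VRAM, so the cached TeaCache tensors can stay on the GPU as well;
# other profiles park them in system RAM to save VRAM, at the cost of extra transfers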
video_no += 1
status = f"Video {video_no}/{total_video}"
@@ -799,6 +803,7 @@ def generate_video(
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
try:
if use_image2video:
samples = wan_model.generate(
@@ -858,6 +863,9 @@ def generate_video(
else:
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
if trans.enable_teacache:
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
if samples != None:
samples = samples.to("cpu")
@@ -874,7 +882,10 @@ def generate_video(
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
if os.name == 'nt':
file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
else:
file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
cache_video(
tensor=sample[None],
@@ -1189,14 +1200,16 @@ def create_demo():
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
tea_cache_setting = gr.Dropdown(
choices=[
("Disabled", 0),
("Fast (x1.6 speed up)", 0.1),
("Faster (x2.1 speed up)", 0.15),
("Tea Cache Disabled", 0),
("0.03 (around x1.6 speed up)", 0.03),
("0.05 (around x2 speed up)", 0.05),
("0.10 (around x3 speed up)", 0.1),
],
value=default_tea_cache,
visible=False,
label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
visible=True,
label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
)
tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment, as a percentage of the generation (the later it starts, the higher the quality but the lower the speed gain)")
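# note: the early denoising steps lay down the global structure of the video, so starting Tea Cache later (a higher percentage)
# generally preserves quality at the cost of fewer skipped steps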
RIFLEx_setting = gr.Dropdown(
choices=[
@@ -1241,6 +1254,7 @@ def create_demo():
embedded_guidance_scale,
repeat_generation,
tea_cache_setting,
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_to_continue,

View File

@@ -316,7 +316,6 @@ class WanI2V:
if callback != None:
callback(-1, None)
self._interrupt = False
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
@@ -324,13 +323,13 @@ class WanI2V:
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
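# current_step and is_uncond let the transformer keep separate TeaCache state (accumulated distance and cached residual)
# for the conditional and unconditional passes of classifier-free guidance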
if self._interrupt:
return None
del latent_model_input

View File

@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
output.append(x_i)
return torch.stack(output) #.float()
def relative_l1_distance(last_tensor, current_tensor):
l1_distance = torch.abs(last_tensor - current_tensor).mean()
norm = torch.abs(last_tensor).mean()
relative_l1_distance = l1_distance / norm
return relative_l1_distance.to(torch.float32)
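# e.g. relative_l1_distance(torch.ones(4), 1.1 * torch.ones(4)) is about 0.1:
# the mean absolute change between the two tensors, normalised by the mean magnitude of the previous one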
class WanRMSNorm(nn.Module):
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):
return freqs
def forward(
self,
x,
@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
y=None,
freqs = None,
pipeline = None,
current_step = 0,
is_uncond=False
):
r"""
Forward pass through the diffusion model
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
should_calc = True
if self.enable_teacache and current_step >= self.teacache_start_step:
if current_step == self.teacache_start_step:
self.accumulated_rel_l1_distance_cond = 0
self.accumulated_rel_l1_distance_uncond = 0
self.teacache_skipped_cond_steps = 0
self.teacache_skipped_uncond_steps = 0
else:
prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context=context,
context_lens=context_lens)
temb_relative_l1 = relative_l1_distance(prev_input, e0)
setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
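# the accumulator sums, over consecutive steps, the relative change |e0_t - e0_{t-1}|.mean() / |e0_{t-1}|.mean()
# of the modulated time embedding, tracked separately for the cond and uncond passes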
for block in self.blocks:
if pipeline._interrupt:
return [None]
if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
should_calc = False
self.teacache_counter += 1
else:
should_calc = True
setattr(self, acc_distance_attr, 0)
if is_uncond:
self.previous_modulated_input_uncond = e0.clone()
if should_calc:
self.previous_residual_uncond = None
else:
x += self.previous_residual_uncond
self.teacache_skipped_uncond_steps += 1
# print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
else:
self.previous_modulated_input_cond = e0.clone()
if should_calc:
self.previous_residual_cond = None
else:
x += self.previous_residual_cond
self.teacache_skipped_cond_steps += 1
# print(f"Skipped cond:{self.teacache_skipped_cond_steps}/{current_step}" )
x = block(x, **kwargs)
if should_calc:
if self.enable_teacache:
ori_hidden_states = x.clone()
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
if pipeline._interrupt:
return [None]
x = block(x, **kwargs)
if self.enable_teacache:
residual = ori_hidden_states # reuse this buffer instead of allocating a new tensor
torch.sub(x, ori_hidden_states, out=residual) # residual = x - ori_hidden_states, computed in place
if is_uncond:
self.previous_residual_uncond = residual
else:
self.previous_residual_cond = residual
del residual, ori_hidden_states
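# In short: while the accumulated relative change of e0 stays below rel_l1_thresh the transformer blocks are
# skipped and the cached residual from the last full pass is re-applied (x += previous_residual); once the
# threshold is crossed, the blocks run again, the accumulator is reset, and the new residual
# (x_after_blocks - x_before_blocks) is cached for the following steps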
# head
x = self.head(x, e)

View File

@@ -248,7 +248,6 @@ class WanT2V:
if callback != None:
callback(-1, None)
self._interrupt = False
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
@@ -257,11 +256,11 @@ class WanT2V:
# self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
if self._interrupt:
return None