Mirror of https://github.com/Wan-Video/Wan2.1.git

Commit ec1159bb59 (parent a15ac428c3): Added TeaCache support
@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

## 🔥 Latest News!!

* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
- Support for all Wan models, including the Image to Video model
- Reduced memory consumption by a factor of 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking videos longer than 5s.
- The usual perks: web interface, multiple generations, loras support, sage attention, auto download of models, ...
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil

### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here:
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slighty faster, but less VRAM
- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
- LowRAM_LowVRAM (4): loads only the part of the model that is needed, low VRAM and low RAM requirement but slightly slower
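Further down, this commit also ties the profile choice to where the TeaCache state is kept: profiles that keep the model in VRAM keep the cache on the GPU, while the low-VRAM profiles fall back to the CPU. A minimal sketch of that mapping (the function name and asserts are mine; only the `"cuda" if profile==3 or profile==1 else "cpu"` rule comes from the diff below):

```python
def teacache_cache_device(profile: int) -> str:
    # Profiles 1 and 3 keep the model (and thus the TeaCache buffers) in VRAM;
    # the low-VRAM profiles park the cache in system RAM instead.
    return "cuda" if profile in (1, 3) else "cpu"

assert teacache_cache_device(3) == "cuda"   # LowRAM_HighVRAM
assert teacache_cache_device(4) == "cpu"    # LowRAM_LowVRAM
```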
@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback

import math

def _parse_args():
    parser = argparse.ArgumentParser(
@@ -650,6 +650,7 @@ def generate_video(
    embedded_guidance_scale,
    repeat_generation,
    tea_cache,
    tea_cache_start_step_perc,
    loras_choices,
    loras_mult_choices,
    image_to_continue,
@@ -783,12 +784,15 @@ def generate_video(
                break

        if trans.enable_teacache:
            trans.num_steps = num_inference_steps
            trans.cnt = 0
            trans.rel_l1_thresh = tea_cache  #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
            trans.accumulated_rel_l1_distance = 0
            trans.previous_modulated_input = None
            trans.previous_residual = None
            trans.teacache_counter = 0
            trans.rel_l1_thresh = tea_cache
            trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc * num_inference_steps / 100), 2)
            trans.previous_residual_uncond = None
            trans.previous_modulated_input_uncond = None
            trans.previous_residual_cond = None
            trans.previous_modulated_input_cond = None

            trans.teacache_cache_device = "cuda" if profile == 3 or profile == 1 else "cpu"

        video_no += 1
        status = f"Video {video_no}/{total_video}"
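The start step computed above is simply a percentage of the total number of denoising steps, clamped so that caching never begins before step 2. A quick worked example with illustrative values:

```python
import math

def teacache_start_step(start_perc: float, num_inference_steps: int) -> int:
    # Same formula as above: never allow TeaCache to kick in before step 2.
    return max(math.ceil(start_perc * num_inference_steps / 100), 2)

print(teacache_start_step(20, 30))  # 6 -> with the default 20%, skipping may start at step 6 of 30
print(teacache_start_step(2, 30))   # 2 -> the lower bound kicks in
```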
@@ -799,6 +803,7 @@ def generate_video(

        gc.collect()
        torch.cuda.empty_cache()
        wan_model._interrupt = False
        try:
            if use_image2video:
                samples = wan_model.generate(
@@ -858,6 +863,9 @@ def generate_video(
        else:
            raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")

        if trans.enable_teacache:
            trans.previous_residual_uncond = None
            trans.previous_residual_cond = None

        if samples != None:
            samples = samples.to("cpu")
@@ -874,7 +882,10 @@ def generate_video(
        # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")

        time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
        file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
        if os.name == 'nt':
            file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
        else:
            file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
        video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
        cache_video(
            tensor=sample[None],
@@ -1189,14 +1200,16 @@ def create_demo():
            flow_shift = gr.Slider(0.0, 25.0, value=default_flow_shift, step=0.1, label="Shift Scale")
            tea_cache_setting = gr.Dropdown(
                choices=[
                    ("Disabled", 0),
                    ("Fast (x1.6 speed up)", 0.1),
                    ("Faster (x2.1 speed up)", 0.15),
                    ("Tea Cache Disabled", 0),
                    ("0.03 (around x1.6 speed up)", 0.03),
                    ("0.05 (around x2 speed up)", 0.05),
                    ("0.10 (around x3 speed up)", 0.1),
                ],
                value=default_tea_cache,
                visible=False,
                label="Tea Cache acceleration (the faster the acceleration, the higher the degradation of the quality of the video; consumes VRAM)"
                visible=True,
                label="Tea Cache threshold used to skip steps (the higher, the more steps are skipped, but the lower the quality of the video; Tea Cache consumes VRAM)"
            )
            tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")

            RIFLEx_setting = gr.Dropdown(
                choices=[
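The speed-ups quoted in the dropdown are rough averages; the real gain depends on how many transformer passes actually get skipped. A back-of-the-envelope estimate from the counter the model keeps, under the assumption that each denoising step issues one conditional and one unconditional pass, both of which TeaCache can skip (the helper name is mine):

```python
def estimated_speedup(num_inference_steps: int, teacache_counter: int) -> float:
    # Without TeaCache, every step runs the transformer twice (cond + uncond).
    total_calls = 2 * num_inference_steps
    executed_calls = max(total_calls - teacache_counter, 1)  # teacache_counter counts skipped calls
    return total_calls / executed_calls

print(estimated_speedup(30, 20))  # 1.5 -> 60 transformer passes reduced to 40
```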
@@ -1241,6 +1254,7 @@ def create_demo():
            embedded_guidance_scale,
            repeat_generation,
            tea_cache_setting,
            tea_cache_start_step_perc,
            loras_choices,
            loras_mult_choices,
            image_to_continue,
@@ -316,7 +316,6 @@ class WanI2V:
        if callback != None:
            callback(-1, None)

        self._interrupt = False
        for i, t in enumerate(tqdm(timesteps)):
            latent_model_input = [latent.to(self.device)]
            timestep = [t]

@@ -324,13 +323,13 @@ class WanI2V:
            timestep = torch.stack(timestep).to(self.device)

            noise_pred_cond = self.model(
                latent_model_input, t=timestep, **arg_c)[0]
                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
            if self._interrupt:
                return None
            if offload_model:
                torch.cuda.empty_cache()
            noise_pred_uncond = self.model(
                latent_model_input, t=timestep, **arg_null)[0]
                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
            if self._interrupt:
                return None
            del latent_model_input
@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
        output.append(x_i)
    return torch.stack(output) #.float()

def relative_l1_distance(last_tensor, current_tensor):
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32)

class WanRMSNorm(nn.Module):
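A toy illustration of how this helper feeds the skip decision in the forward pass below: the relative L1 distance between successive modulated time embeddings is accumulated, and the transformer blocks are only recomputed once the accumulated drift crosses the threshold. The tensors and the 0.05 threshold here are made up, and `relative_l1_distance` is assumed to be the function defined just above:

```python
import torch

threshold, accumulated, prev = 0.05, 0.0, None
base = torch.randn(1, 6, 256)
for step in range(6):
    e0 = base + 0.01 * torch.randn_like(base)   # stand-in for the modulated time embedding
    if prev is None:
        print(step, "compute blocks (nothing cached yet)")
    else:
        accumulated += relative_l1_distance(prev, e0)
        if accumulated < threshold:
            print(step, "skip blocks, reuse cached residual")
        else:
            print(step, "recompute blocks and reset the accumulator")
            accumulated = 0.0
    prev = e0.clone()
```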
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):

        return freqs

    def forward(
        self,
        x,

@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
        y=None,
        freqs=None,
        pipeline=None,
        current_step=0,
        is_uncond=False
    ):
        r"""
        Forward pass through the diffusion model
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
        # assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)
        # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
        should_calc = True
        if self.enable_teacache and current_step >= self.teacache_start_step:
            if current_step == self.teacache_start_step:
                self.accumulated_rel_l1_distance_cond = 0
                self.accumulated_rel_l1_distance_uncond = 0
                self.teacache_skipped_cond_steps = 0
                self.teacache_skipped_uncond_steps = 0
            else:
                prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
                acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=freqs,
            context=context,
            context_lens=context_lens)
                temb_relative_l1 = relative_l1_distance(prev_input, e0)
                setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)

        for block in self.blocks:
            if pipeline._interrupt:
                return [None]
                if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
                    should_calc = False
                    self.teacache_counter += 1
                else:
                    should_calc = True
                    setattr(self, acc_distance_attr, 0)

            x = block(x, **kwargs)
            if is_uncond:
                self.previous_modulated_input_uncond = e0.clone()
                if should_calc:
                    self.previous_residual_uncond = None
                else:
                    x += self.previous_residual_uncond
                    self.teacache_skipped_cond_steps += 1
                    # print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
            else:
                self.previous_modulated_input_cond = e0.clone()
                if should_calc:
                    self.previous_residual_cond = None
                else:
                    x += self.previous_residual_cond
                    self.teacache_skipped_uncond_steps += 1
                    # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )

        if should_calc:
            if self.enable_teacache:
                ori_hidden_states = x.clone()
            # arguments
            kwargs = dict(
                e=e0,
                seq_lens=seq_lens,
                grid_sizes=grid_sizes,
                freqs=freqs,
                context=context,
                context_lens=context_lens)

            for block in self.blocks:
                if pipeline._interrupt:
                    return [None]

                x = block(x, **kwargs)

            if self.enable_teacache:
                residual = ori_hidden_states  # reuse the same buffer, just to keep the code readable
                torch.sub(x, ori_hidden_states, out=residual)
                if is_uncond:
                    self.previous_residual_uncond = residual
                else:
                    self.previous_residual_cond = residual
                del residual, ori_hidden_states

        # head
        x = self.head(x, e)
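Stripped of the cond/uncond bookkeeping and the interrupt checks, the caching pattern added above reduces to the following (a simplified sketch, not the commit's exact code; `blocks` and the `cache` dict are placeholders):

```python
import torch

def run_blocks_with_teacache(x, blocks, cache, should_calc):
    if not should_calc and cache.get("residual") is not None:
        # Skipped pass: re-apply the cached effect of all transformer blocks.
        return x + cache["residual"]

    ori = x.clone()
    for block in blocks:
        x = block(x)
    # Remember what the blocks added so a later, similar step can reuse it.
    cache["residual"] = x - ori
    return x
```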
@@ -248,7 +248,6 @@ class WanT2V:

        if callback != None:
            callback(-1, None)
        self._interrupt = False
        for i, t in enumerate(tqdm(timesteps)):
            latent_model_input = latents
            timestep = [t]

@@ -257,11 +256,11 @@ class WanT2V:

            # self.model.to(self.device)
            noise_pred_cond = self.model(
                latent_model_input, t=timestep, **arg_c)[0]
                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
            if self._interrupt:
                return None
            noise_pred_uncond = self.model(
                latent_model_input, t=timestep, **arg_null)[0]
                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
            if self._interrupt:
                return None