Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 22:26:36 +00:00)
Added TeaCache support

parent: a15ac428c3
commit: ec1159bb59
@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
 
-* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
+* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
+* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
     - Support for all Wan including the Image to Video model
     - Reduced memory consumption by 2, with possibility to generate more than 10s of video at 720p with a RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking videos longer than 5s.
     - The usual perks: web interface, multiple generations, loras support, sage attention, auto download of models, ...
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
-- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slighty faster, but less VRAM
+- LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
 - LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
 
@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
 import torch
 import gc
 import traceback
-
+import math
 
 def _parse_args():
     parser = argparse.ArgumentParser(
@@ -650,6 +650,7 @@ def generate_video(
     embedded_guidance_scale,
     repeat_generation,
     tea_cache,
+    tea_cache_start_step_perc,
     loras_choices,
     loras_mult_choices,
     image_to_continue,
@@ -783,12 +784,15 @@ def generate_video(
                 break
 
     if trans.enable_teacache:
-        trans.num_steps = num_inference_steps
-        trans.cnt = 0
-        trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
-        trans.accumulated_rel_l1_distance = 0
-        trans.previous_modulated_input = None
-        trans.previous_residual = None
+        trans.teacache_counter = 0
+        trans.rel_l1_thresh = tea_cache
+        trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc * num_inference_steps / 100), 2)
+        trans.previous_residual_uncond = None
+        trans.previous_modulated_input_uncond = None
+        trans.previous_residual_cond = None
+        trans.previous_modulated_input_cond = None
+
+        trans.teacache_cache_device = "cuda" if profile == 3 or profile == 1 else "cpu"
 
     video_no += 1
     status = f"Video {video_no}/{total_video}"
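For reference, this is how the two new UI settings turn into per-run TeaCache parameters; a minimal sketch reusing the names from the hunk above (the example values are arbitrary):

```python
import math

def teacache_params(tea_cache: float, tea_cache_start_step_perc: int, num_inference_steps: int):
    """Mirrors the setup above: the dropdown value becomes the relative-L1
    threshold, and the slider percentage is turned into an absolute step index
    (never earlier than step 2)."""
    rel_l1_thresh = tea_cache
    teacache_start_step = max(math.ceil(tea_cache_start_step_perc * num_inference_steps / 100), 2)
    return rel_l1_thresh, teacache_start_step

# e.g. a 0.05 threshold starting at 20% of a 30-step generation -> (0.05, 6)
print(teacache_params(0.05, 20, 30))
```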
@@ -799,6 +803,7 @@ def generate_video(
 
         gc.collect()
         torch.cuda.empty_cache()
+        wan_model._interrupt = False
         try:
             if use_image2video:
                 samples = wan_model.generate(
@@ -858,6 +863,9 @@ def generate_video(
             else:
                 raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
 
+        if trans.enable_teacache:
+            trans.previous_residual_uncond = None
+            trans.previous_residual_cond = None
 
         if samples != None:
             samples = samples.to("cpu")
@@ -874,7 +882,10 @@ def generate_video(
         # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
 
         time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
-        file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
+        if os.name == 'nt':
+            file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
+        else:
+            file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
         video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
         cache_video(
             tensor=sample[None],
@@ -1189,14 +1200,16 @@ def create_demo():
                     flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
                     tea_cache_setting = gr.Dropdown(
                         choices=[
-                            ("Disabled", 0),
-                            ("Fast (x1.6 speed up)", 0.1),
-                            ("Faster (x2.1 speed up)", 0.15),
+                            ("Tea Cache Disabled", 0),
+                            ("0.03 (around x1.6 speed up)", 0.03),
+                            ("0.05 (around x2 speed up)", 0.05),
+                            ("0.10 (around x3 speed up)", 0.1),
                         ],
                         value=default_tea_cache,
-                        visible=False,
-                        label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
+                        visible=True,
+                        label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped, but the lower the quality of the video; Tea Cache consumes VRAM)"
                     )
+                    tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
 
                     RIFLEx_setting = gr.Dropdown(
                         choices=[
@@ -1241,6 +1254,7 @@ def create_demo():
                     embedded_guidance_scale,
                     repeat_generation,
                     tea_cache_setting,
+                    tea_cache_start_step_perc,
                     loras_choices,
                     loras_mult_choices,
                     image_to_continue,
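For context, a self-contained sketch of how Gradio controls like the dropdown and slider above end up as positional arguments of the generation callback. `fake_generate` is a hypothetical stand-in for `generate_video`, not the project's actual code:

```python
import gradio as gr

def fake_generate(tea_cache, tea_cache_start_step_perc):
    # hypothetical stand-in for generate_video(): just echo the two TeaCache settings
    return f"threshold={tea_cache}, start at {tea_cache_start_step_perc}% of the steps"

with gr.Blocks() as demo:
    tea_cache_setting = gr.Dropdown(
        choices=[("Tea Cache Disabled", 0), ("0.03 (around x1.6 speed up)", 0.03),
                 ("0.05 (around x2 speed up)", 0.05), ("0.10 (around x3 speed up)", 0.1)],
        value=0,
        label="Tea Cache Threshold to Skip Steps")
    tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1,
        label="Tea Cache starting moment in percentage of generation")
    out = gr.Textbox()
    btn = gr.Button("Generate")
    # the order of `inputs` must match the callback's parameter order, which is why
    # tea_cache_start_step_perc is inserted right after tea_cache_setting in the list above
    btn.click(fake_generate, inputs=[tea_cache_setting, tea_cache_start_step_perc], outputs=out)

demo.launch()
```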
@@ -316,7 +316,6 @@ class WanI2V:
         if callback != None:
             callback(-1, None)
 
-        self._interrupt = False
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
@@ -324,13 +323,13 @@ class WanI2V:
             timestep = torch.stack(timestep).to(self.device)
 
             noise_pred_cond = self.model(
-                latent_model_input, t=timestep, **arg_c)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
             if self._interrupt:
                 return None
             if offload_model:
                 torch.cuda.empty_cache()
             noise_pred_uncond = self.model(
-                latent_model_input, t=timestep, **arg_null)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
             if self._interrupt:
                 return None
             del latent_model_input
@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
         output.append(x_i)
     return torch.stack(output) #.float()
 
+def relative_l1_distance(last_tensor, current_tensor):
+    l1_distance = torch.abs(last_tensor - current_tensor).mean()
+    norm = torch.abs(last_tensor).mean()
+    relative_l1_distance = l1_distance / norm
+    return relative_l1_distance.to(torch.float32)
 
 class WanRMSNorm(nn.Module):
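The helper added above is the quantity TeaCache thresholds against: the mean absolute change of the time-modulated input between two steps, normalised by the mean magnitude of the previous one. A quick sanity check with made-up tensors (values chosen so the arithmetic is easy to follow):

```python
import torch

def relative_l1_distance(last_tensor, current_tensor):
    # same definition as the helper above: mean |delta| normalised by mean |previous|
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    return (l1_distance / norm).to(torch.float32)

prev = torch.tensor([1.0, 2.0, 3.0, 4.0])
cur  = torch.tensor([1.1, 2.1, 2.9, 4.2])
# mean |delta| = 0.125, mean |prev| = 2.5  ->  relative L1 distance ~= 0.05
print(relative_l1_distance(prev, cur))
```

In the forward pass further down, these per-step distances are accumulated separately for the conditional and unconditional passes until they reach the chosen threshold, at which point a full pass through the transformer blocks is run again.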
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         return freqs
 
+
     def forward(
         self,
         x,
@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
         y=None,
         freqs=None,
         pipeline=None,
+        current_step=0,
+        is_uncond=False
     ):
         r"""
         Forward pass through the diffusion model
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
         e = self.time_embedding(
             sinusoidal_embedding_1d(self.freq_dim, t))
         e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
-        # assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
         # context
         context_lens = None
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)
 
-        # arguments
-        kwargs = dict(
-            e=e0,
-            seq_lens=seq_lens,
-            grid_sizes=grid_sizes,
-            freqs=freqs,
-            context=context,
-            context_lens=context_lens)
-
-        for block in self.blocks:
-            if pipeline._interrupt:
-                return [None]
-
-            x = block(x, **kwargs)
+        # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
+        should_calc = True
+        if self.enable_teacache and current_step >= self.teacache_start_step:
+            if current_step == self.teacache_start_step:
+                self.accumulated_rel_l1_distance_cond = 0
+                self.accumulated_rel_l1_distance_uncond = 0
+                self.teacache_skipped_cond_steps = 0
+                self.teacache_skipped_uncond_steps = 0
+            else:
+                prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
+                acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
+
+                temb_relative_l1 = relative_l1_distance(prev_input, e0)
+                setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
+
+                if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
+                    should_calc = False
+                    self.teacache_counter += 1
+                else:
+                    should_calc = True
+                    setattr(self, acc_distance_attr, 0)
+
+            if is_uncond:
+                self.previous_modulated_input_uncond = e0.clone()
+                if should_calc:
+                    self.previous_residual_uncond = None
+                else:
+                    x += self.previous_residual_uncond
+                    self.teacache_skipped_uncond_steps += 1
+                    # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}")
+            else:
+                self.previous_modulated_input_cond = e0.clone()
+                if should_calc:
+                    self.previous_residual_cond = None
+                else:
+                    x += self.previous_residual_cond
+                    self.teacache_skipped_cond_steps += 1
+                    # print(f"Skipped cond:{self.teacache_skipped_cond_steps}/{current_step}")
+
+        if should_calc:
+            if self.enable_teacache:
+                ori_hidden_states = x.clone()
+
+            # arguments
+            kwargs = dict(
+                e=e0,
+                seq_lens=seq_lens,
+                grid_sizes=grid_sizes,
+                freqs=freqs,
+                context=context,
+                context_lens=context_lens)
+
+            for block in self.blocks:
+                if pipeline._interrupt:
+                    return [None]
+
+                x = block(x, **kwargs)
+
+            if self.enable_teacache:
+                residual = ori_hidden_states  # reuse the buffer, just to keep the code readable
+                torch.sub(x, ori_hidden_states, out=residual)
+                if is_uncond:
+                    self.previous_residual_uncond = residual
+                else:
+                    self.previous_residual_cond = residual
+                del residual, ori_hidden_states
 
         # head
         x = self.head(x, e)
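Stripped of the separate cond/uncond bookkeeping, the per-pass decision above boils down to roughly the following. This is a simplified single-stream sketch with hypothetical names (`TeaCacheState` is not a class in the repository):

```python
import torch

class TeaCacheState:
    """Simplified, single-stream version of the bookkeeping the forward pass
    keeps per cond/uncond branch. An illustration, not the project's code."""
    def __init__(self, rel_l1_thresh, start_step):
        self.rel_l1_thresh = rel_l1_thresh
        self.start_step = start_step
        self.accumulated = 0.0
        self.prev_modulated = None   # e0 from the previous step
        self.prev_residual = None    # x_after_blocks - x_before_blocks from the last full pass

    def should_calc(self, step, e0):
        # before (or on) the start step, always run the transformer blocks
        if step <= self.start_step or self.prev_modulated is None:
            self.prev_modulated = e0.clone()
            return True
        # accumulate the relative L1 change of the time-modulated input
        delta = (e0 - self.prev_modulated).abs().mean() / self.prev_modulated.abs().mean()
        self.prev_modulated = e0.clone()
        self.accumulated += float(delta)
        if self.accumulated < self.rel_l1_thresh:
            return False             # cheap step: reuse prev_residual, i.e. x += prev_residual
        self.accumulated = 0.0
        return True                  # expensive step: run all blocks, store a fresh residual

state = TeaCacheState(rel_l1_thresh=0.05, start_step=6)
e0 = torch.randn(1, 6, 2048)
print(state.should_calc(6, e0))      # True: at the start step the blocks always run
```

On a skipped step the cached residual from the last full pass is simply added back onto the hidden states, which is why generate_video clears previous_residual_cond and previous_residual_uncond after each video and after errors.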
@@ -248,7 +248,6 @@ class WanT2V:
 
         if callback != None:
             callback(-1, None)
-        self._interrupt = False
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
@@ -257,11 +256,11 @@ class WanT2V:
 
             # self.model.to(self.device)
             noise_pred_cond = self.model(
-                latent_model_input, t=timestep, **arg_c)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
             if self._interrupt:
                 return None
             noise_pred_uncond = self.model(
-                latent_model_input, t=timestep, **arg_null)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
             if self._interrupt:
                 return None