From ec1159bb59ba70e5f07536bf678e8aa82b5abe05 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Mon, 3 Mar 2025 18:41:33 +0100
Subject: [PATCH] Added TeaCache support

---
 README.md            |  5 +--
 gradio_server.py     | 40 ++++++++++++++-------
 wan/image2video.py   |  5 ++-
 wan/modules/model.py | 83 +++++++++++++++++++++++++++++++++++++-------
 wan/text2video.py    |  5 ++-
 5 files changed, 104 insertions(+), 34 deletions(-)

diff --git a/README.md b/README.md
index c95eae8..21cf2d2 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 
 ## 🔥 Latest News!!
 
-* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
+* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations, an optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of TeaCache (https://github.com/ali-vilab/TeaCache)
+* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
 - Support for all Wan including the Image to Video model
 - Reduced memory consumption by 2, with possiblity to generate more than 10s of video at 720p with a RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to REFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking video longer than 5s.
 - The usual perks: web interface, multiple generations, loras support, sage attebtion, auto download of models, ...
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
 
 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
-- LowRAM_HighVRAM  (3): loads entirely the model in VRAM, slighty faster, but less VRAM
+- LowRAM_HighVRAM  (3): loads the entire model into VRAM, slightly faster, but needs more VRAM
 - LowRAM_LowVRAM  (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
 
 
diff --git a/gradio_server.py b/gradio_server.py
index 0f07406..7e88037 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
 import torch
 import gc
 import traceback
-
+import math
 
 def _parse_args():
     parser = argparse.ArgumentParser(
@@ -650,6 +650,7 @@ def generate_video(
     embedded_guidance_scale,
     repeat_generation,
     tea_cache,
+    tea_cache_start_step_perc,
     loras_choices,
     loras_mult_choices,
     image_to_continue,
@@ -783,12 +784,15 @@ def generate_video(
                 break
 
         if trans.enable_teacache:
-            trans.num_steps = num_inference_steps
-            trans.cnt = 0
-            trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
-            trans.accumulated_rel_l1_distance = 0
-            trans.previous_modulated_input = None
-            trans.previous_residual = None
+            trans.teacache_counter = 0
+            trans.rel_l1_thresh = tea_cache
+            trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc * num_inference_steps / 100), 2)
+            trans.previous_residual_uncond = None
+            trans.previous_modulated_input_uncond = None
+            trans.previous_residual_cond = None
+            trans.previous_modulated_input_cond = None
+
+            trans.teacache_cache_device = "cuda" if profile == 3 or profile == 1 else "cpu"
 
         video_no += 1
         status = f"Video {video_no}/{total_video}"
@@ -799,6 +803,7 @@ def generate_video(
 
         gc.collect()
         torch.cuda.empty_cache()
+        wan_model._interrupt = False
         try:
             if use_image2video:
                 samples = wan_model.generate(
@@ -858,6 +863,9 @@ def generate_video(
             else:
                 raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
 
+        if trans.enable_teacache:
+            trans.previous_residual_uncond = None
+            trans.previous_residual_cond = None
 
         if samples != None:
             samples = samples.to("cpu")
 
@@ -874,7 +882,10 @@ def generate_video(
 
             # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
             time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
-            file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
+            if os.name == 'nt':
+                file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
+            else:
+                file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
             video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
             cache_video(
                 tensor=sample[None],
@@ -1189,14 +1200,16 @@ def create_demo():
                     flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
                     tea_cache_setting = gr.Dropdown(
                         choices=[
-                            ("Disabled", 0),
-                            ("Fast (x1.6 speed up)", 0.1),
-                            ("Faster (x2.1 speed up)", 0.15),
+                            ("Tea Cache Disabled", 0),
+                            ("0.03 (around x1.6 speed up)", 0.03),
+                            ("0.05 (around x2 speed up)", 0.05),
+                            ("0.10 (around x3 speed up)", 0.1),
                         ],
                         value=default_tea_cache,
-                        visible=False,
-                        label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
+                        visible=True,
+                        label="Tea Cache threshold used to skip steps (the higher the threshold, the more steps are skipped, but the lower the video quality; Tea Cache consumes VRAM)"
                     )
+                    tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting step, as a percentage of the total generation steps (the later it starts, the higher the quality but the lower the speed gain)")
 
                     RIFLEx_setting = gr.Dropdown(
                         choices=[
@@ -1241,6 +1254,7 @@ def create_demo():
                     embedded_guidance_scale,
                     repeat_generation,
                     tea_cache_setting,
+                    tea_cache_start_step_perc,
                     loras_choices,
                     loras_mult_choices,
                     image_to_continue,
diff --git a/wan/image2video.py b/wan/image2video.py
index 6755795..2f9aeea 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -316,7 +316,6 @@ class WanI2V:
 
         if callback != None:
             callback(-1, None)
-        self._interrupt = False
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
@@ -324,13 +323,13 @@ class WanI2V:
             timestep = torch.stack(timestep).to(self.device)
 
             noise_pred_cond = self.model(
-                latent_model_input, t=timestep, **arg_c)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
             if self._interrupt:
                 return None
             if offload_model:
                 torch.cuda.empty_cache()
             noise_pred_uncond = self.model(
-                latent_model_input, t=timestep, **arg_null)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
             if self._interrupt:
                 return None
             del latent_model_input
diff --git a/wan/modules/model.py b/wan/modules/model.py
index a8b5c74..e0afa99 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
         output.append(x_i)
     return torch.stack(output) #.float()
 
+def relative_l1_distance(last_tensor, current_tensor):
+    l1_distance = torch.abs(last_tensor - current_tensor).mean()
+    norm = torch.abs(last_tensor).mean()
+    relative_l1_distance = l1_distance / norm
+    return relative_l1_distance.to(torch.float32)
 
 class WanRMSNorm(nn.Module):
 
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         return freqs
 
+
     def forward(
         self,
         x,
@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
         y=None,
         freqs = None,
         pipeline = None,
+        current_step = 0,
+        is_uncond = False
     ):
         r"""
         Forward pass through the diffusion model
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
         e = self.time_embedding(
             sinusoidal_embedding_1d(self.freq_dim, t))
         e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
-        # assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
         # context
         context_lens = None
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)
+        # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
+        should_calc = True
+        if self.enable_teacache and current_step >= self.teacache_start_step:
+            if current_step == self.teacache_start_step:
+                self.accumulated_rel_l1_distance_cond = 0
+                self.accumulated_rel_l1_distance_uncond = 0
+                self.teacache_skipped_cond_steps = 0
+                self.teacache_skipped_uncond_steps = 0
+            else:
+                prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
+                acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
 
-        # arguments
-        kwargs = dict(
-            e=e0,
-            seq_lens=seq_lens,
-            grid_sizes=grid_sizes,
-            freqs=freqs,
-            context=context,
-            context_lens=context_lens)
+                temb_relative_l1 = relative_l1_distance(prev_input, e0)
+                setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
 
-        for block in self.blocks:
-            if pipeline._interrupt:
-                return [None]
+                if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
+                    should_calc = False
+                    self.teacache_counter += 1
+                else:
+                    should_calc = True
+                    setattr(self, acc_distance_attr, 0)
+
+            if is_uncond:
+                self.previous_modulated_input_uncond = e0.clone()
+                if should_calc:
+                    self.previous_residual_uncond = None
+                else:
+                    x += self.previous_residual_uncond
+                    self.teacache_skipped_uncond_steps += 1
+                    # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
+            else:
+                self.previous_modulated_input_cond = e0.clone()
+                if should_calc:
+                    self.previous_residual_cond = None
+                else:
+                    x += self.previous_residual_cond
+                    self.teacache_skipped_cond_steps += 1
+                    # print(f"Skipped cond:{self.teacache_skipped_cond_steps}/{current_step}" )
 
-            x = block(x, **kwargs)
+        if should_calc:
+            if self.enable_teacache:
+                ori_hidden_states = x.clone()
+            # arguments
+            kwargs = dict(
+                e=e0,
+                seq_lens=seq_lens,
+                grid_sizes=grid_sizes,
+                freqs=freqs,
+                context=context,
+                context_lens=context_lens)
+
+            for block in self.blocks:
+                if pipeline._interrupt:
+                    return [None]
+
+                x = block(x, **kwargs)
+
+            if self.enable_teacache:
+                residual = ori_hidden_states  # reuse the same buffer under a readable name
+                torch.sub(x, ori_hidden_states, out=residual)
+                if is_uncond:
+                    self.previous_residual_uncond = residual
+                else:
+                    self.previous_residual_cond = residual
+                del residual, ori_hidden_states
 
         # head
         x = self.head(x, e)
diff --git a/wan/text2video.py b/wan/text2video.py
index a46c2d6..2f1d9d0 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -248,7 +248,6 @@ class WanT2V:
 
         if callback != None:
             callback(-1, None)
-        self._interrupt = False
        for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
@@ -257,11 +256,11 @@ class WanT2V:
 
             # self.model.to(self.device)
             noise_pred_cond = self.model(
-                latent_model_input, t=timestep, **arg_c)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=False, **arg_c)[0]
             if self._interrupt:
                 return None
 
             noise_pred_uncond = self.model(
-                latent_model_input, t=timestep, **arg_null)[0]
+                latent_model_input, t=timestep, current_step=i, is_uncond=True, **arg_null)[0]
             if self._interrupt:
                 return None
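
The sketch below (not part of the diff) condenses the step-skipping rule that the model.py hunk adds to WanModel.forward: the relative L1 distance between successive modulated time embeddings is accumulated separately for the conditional and unconditional passes, the transformer blocks are skipped and the cached residual reused while that accumulator stays below rel_l1_thresh, and everything is recomputed (with the accumulator reset) once it crosses the threshold; nothing is ever skipped before teacache_start_step. The TeaCacheGate class name, the dict-based state and the toy drift loop are illustrative assumptions rather than code from the patch.

# Illustrative sketch only, assuming the gating rule described above.
import math
import torch


def relative_l1_distance(last_tensor, current_tensor):
    # Same metric as the patch: mean absolute change, normalised by the previous magnitude.
    l1 = torch.abs(last_tensor - current_tensor).mean()
    return (l1 / torch.abs(last_tensor).mean()).to(torch.float32)


class TeaCacheGate:
    """Decides, separately for the cond and uncond passes, whether to run the blocks."""

    def __init__(self, rel_l1_thresh, start_step):
        self.rel_l1_thresh = rel_l1_thresh
        self.start_step = start_step
        self.prev_e0 = {}                        # last modulated time embedding, per pass
        self.acc = {}                            # accumulated relative-L1 distance, per pass
        self.skipped = {"cond": 0, "uncond": 0}

    def should_calc(self, e0, current_step, is_uncond):
        key = "uncond" if is_uncond else "cond"
        if current_step < self.start_step:
            return True                          # warm-up steps are always computed
        if current_step == self.start_step:
            self.acc[key] = 0.0                  # reset the accumulator at the start step
            calc = True
        else:
            self.acc[key] += relative_l1_distance(self.prev_e0[key], e0)
            if self.acc[key] < self.rel_l1_thresh:
                calc = False                     # close enough: reuse the cached residual
                self.skipped[key] += 1
            else:
                calc = True
                self.acc[key] = 0.0              # drifted too far: recompute and reset
        self.prev_e0[key] = e0.clone()           # remember this step's modulated embedding
        return calc


# Toy run: threshold 0.05, start at 20% of 30 steps (never earlier than step 2), as in the patch.
num_steps = 30
gate = TeaCacheGate(rel_l1_thresh=0.05, start_step=max(math.ceil(20 * num_steps / 100), 2))
e0 = torch.randn(1, 6, 128)
for step in range(num_steps):
    e0 = e0 + 0.01 * torch.randn_like(e0)        # pretend the embedding drifts slowly
    run_blocks = gate.should_calc(e0, step, is_uncond=False)
print("skipped cond steps:", gate.skipped["cond"])

The dropdown added in gradio_server.py maps thresholds to approximate speed-ups (0.03 ≈ x1.6, 0.05 ≈ x2, 0.10 ≈ x3), and the start step comes from the new percentage slider via max(ceil(perc * num_inference_steps / 100), 2), which the demo loop mirrors.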