mag cache support for Wan 2.2

deepbeepmeep 2025-08-09 20:12:51 +02:00
parent b8d843e7db
commit 8e300f791a
15 changed files with 362 additions and 239 deletions

View File

@@ -20,6 +20,11 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

 ## 🔥 Latest Updates :
+### August 8 2025: WanGP v7.74 - Qwen Rebirth part 2
+Added support for the Qwen Lightning lora for an 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The lora is not normalized and you can use a multiplier around 0.1.
+Mag Cache support for all the Wan 2.2 models.
 ### August 8 2025: WanGP v7.73 - Qwen Rebirth
 Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71 whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him.
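
Note on the "multiplier around 0.1" advice: since the Lightning lora is not normalized, its low-rank update is large and must be scaled down when applied. A minimal, hypothetical sketch of how a lora multiplier scales the update before it is merged into a base weight (this is not WanGP's actual loader; all names below are illustrative):

```python
# Hypothetical sketch: scaling an un-normalized lora update with a small multiplier.
import torch

def merge_lora_weight(base_weight: torch.Tensor,
                      lora_down: torch.Tensor,
                      lora_up: torch.Tensor,
                      multiplier: float = 0.1) -> torch.Tensor:
    # base_weight: (out, in), lora_down: (rank, in), lora_up: (out, rank)
    update = lora_up @ lora_down            # low-rank delta, same shape as base_weight
    return base_weight + multiplier * update

w = torch.randn(64, 64)
down, up = torch.randn(8, 64), torch.randn(64, 8)
merged = merge_lora_weight(w, down, up, multiplier=0.1)
print(merged.shape)  # torch.Size([64, 64])
```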

View File

@@ -5,8 +5,7 @@
     "architecture" : "fantasy",
     "modules": ["fantasy"],
     "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio Input.",
-    "URLs": "i2v_720p",
-    "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+    "URLs": "i2v_720p"
   },
   "resolution": "1280x720"
 }
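
The per-model `teacache_coefficients` entry is dropped here because the TeaCache polynomial now comes from the family handler via the cache object (see the handler hunks further down). A small sketch, assuming the coefficient list removed above, of how such a polynomial rescales the relative L1 change between steps:

```python
# Minimal sketch of the TeaCache rescaling step: np.poly1d turns the coefficient
# list into a polynomial evaluated on the relative L1 distance of the modulated input.
import numpy as np

coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
rescale = np.poly1d(coefficients)

rel_l1 = 0.05                      # example relative L1 distance between two steps
accumulated = rescale(rel_l1)      # value accumulated until it crosses rel_l1_thresh
print(float(accumulated))
```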

View File

@@ -949,15 +949,18 @@ class HunyuanVideoPipeline(DiffusionPipeline):
         # width = width or self.transformer.config.sample_size * self.vae_scale_factor
         # to deal with lora scaling and other possible forward hooks
         trans = self.transformer
-        if trans.enable_cache == "tea":
-            teacache_multiplier = trans.cache_multiplier
-            trans.accumulated_rel_l1_distance = 0
-            trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
-        elif trans.enable_cache == "mag":
-            trans.compute_magcache_threshold(trans.cache_start_step, num_inference_steps, trans.cache_multiplier)
-            trans.accumulated_err, trans.accumulated_steps, trans.accumulated_ratio = 0, 0, 1.0
-        else:
-            trans.enable_cache == None
+        skip_steps_cache = trans.cache
+        if skip_steps_cache != None:
+            cache_type = skip_steps_cache.cache_type
+            if cache_type == "tea":
+                teacache_multiplier = skip_steps_cache.multiplier
+                skip_steps_cache.accumulated_rel_l1_distance = 0
+                skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
+            elif cache_type == "mag":
+                trans.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier)
+                skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0
+        else:
+            trans.cache = None

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,

@@ -1212,8 +1215,8 @@ class HunyuanVideoPipeline(DiffusionPipeline):
         if ip_cfg_scale>0:
             latent_items += 1

-        if self.transformer.enable_cache:
-            self.transformer.previous_residual = [None] * latent_items
+        if skip_steps_cache != None:
+            skip_steps_cache.previous_residual = [None] * latent_items

         # if is_progress_bar:
         with self.progress_bar(total=num_inference_steps) as progress_bar:
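
The refactor above moves the per-run skip-step state off the transformer and onto a `cache` object attached to it. A hedged, self-contained sketch of the reset pattern at the start of sampling (a `SimpleNamespace` stands in for the holder built in wgp.py; the real mag branch additionally recomputes the threshold via `compute_magcache_threshold`):

```python
# Sketch of the new cache-reset pattern, under the assumption of a namespace-like holder.
from types import SimpleNamespace

def reset_skip_steps_cache(trans, num_inference_steps: int) -> None:
    cache = getattr(trans, "cache", None)
    if cache is None:
        return
    if cache.cache_type == "tea":
        # smaller speed multipliers get a tighter skip threshold
        cache.accumulated_rel_l1_distance = 0
        cache.rel_l1_thresh = 0.1 if cache.multiplier < 2 else 0.15
    elif cache.cache_type == "mag":
        cache.accumulated_err, cache.accumulated_steps, cache.accumulated_ratio = 0, 0, 1.0

trans = SimpleNamespace(cache=SimpleNamespace(cache_type="tea", multiplier=1.75))
reset_skip_steps_cache(trans, num_inference_steps=30)
print(trans.cache.rel_l1_thresh)  # 0.1
```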

View File

@@ -934,15 +934,20 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
         transformer = self.transformer
-        if transformer.enable_cache == "tea":
-            teacache_multiplier = transformer.cache_multiplier
-            transformer.accumulated_rel_l1_distance = 0
-            transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
-        elif transformer.enable_cache == "mag":
-            transformer.compute_magcache_threshold(transformer.cache_start_step, num_inference_steps, transformer.cache_multiplier)
-            transformer.accumulated_err, transformer.accumulated_steps, transformer.accumulated_ratio = 0, 0, 1.0
-        else:
-            transformer.enable_cache == None
+        skip_steps_cache = transformer.cache
+        cache_type = None
+        if skip_steps_cache != None:
+            cache_type = skip_steps_cache.cache_type
+            if cache_type == "tea":
+                teacache_multiplier = skip_steps_cache.multiplier
+                skip_steps_cache.accumulated_rel_l1_distance = 0
+                skip_steps_cache.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
+            elif cache_type == "mag":
+                transformer.compute_magcache_threshold(skip_steps_cache.start_step, num_inference_steps, skip_steps_cache.multiplier)
+                skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = 0, 0, 1.0
+            else:
+                transformer.cache = None
+                cache_type = None

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(

@@ -1141,16 +1146,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
             if self._interrupt:
                 return [None]

-            if transformer.enable_cache == "tea":
+            if cache_type == "tea":
                 cache_size = round( infer_length / frames_per_batch )
-                transformer.previous_residual = [None] * latent_items
+                skip_steps_cache.previous_residual = [None] * latent_items
                 cache_all_previous_residual = [None] * latent_items
                 cache_all_previous_modulated_input = None
                 cache_should_calc = [True] * cache_size
                 cache_accumulated_rel_l1_distance = [0.] * cache_size
                 cache_teacache_skipped_steps = [0] * cache_size
-            elif transformer.enable_cache == "mag":
-                transformer.previous_residual = [None] * latent_items
+            elif cache_type == "mag":
+                skip_steps_cache.previous_residual = [None] * latent_items

             with self.progress_bar(total=num_inference_steps) as progress_bar:

@@ -1187,16 +1192,16 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
                     img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
                     img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]

-                    if transformer.enable_cache == "tea" and cache_size > 1:
+                    if cache_type == "tea" and cache_size > 1:
                         for l in range(latent_items):
                             if cache_all_previous_residual[l] != None:
                                 bsz = cache_all_previous_residual[l].shape[0]
-                                transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
+                                skip_steps_cache.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
                         if cache_all_previous_modulated_input != None:
-                            transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
-                        transformer.should_calc = cache_should_calc[cache_slot_no]
-                        transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
-                        transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]
+                            skip_steps_cache.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
+                        skip_steps_cache.should_calc = cache_should_calc[cache_slot_no]
+                        skip_steps_cache.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
+                        skip_steps_cache.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]

                     if self.do_classifier_free_guidance:

@@ -1304,21 +1309,21 @@ class HunyuanVideoAudioPipeline(DiffusionPipeline):
                         pred_latents[:, :, p] += latents[:, :, iii]
                         counter[:, :, p] += 1

-                    if transformer.enable_cache == "tea" and cache_size > 1:
+                    if cache_type == "tea" and cache_size > 1:
                         for l in range(latent_items):
-                            if transformer.previous_residual[l] != None:
-                                bsz = transformer.previous_residual[l].shape[0]
+                            if skip_steps_cache.previous_residual[l] != None:
+                                bsz = skip_steps_cache.previous_residual[l].shape[0]
                                 if cache_all_previous_residual[l] == None:
-                                    cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype)
-                                cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)
+                                    cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=skip_steps_cache.previous_residual[l].device, dtype=skip_steps_cache.previous_residual[l].dtype)
+                                cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)

-                        if transformer.previous_modulated_input != None:
+                        if skip_steps_cache.previous_modulated_input != None:
                             if cache_all_previous_modulated_input == None:
-                                cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype)
-                            cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
-                        cache_should_calc[cache_slot_no] = transformer.should_calc
-                        cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance
-                        cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps
+                                cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=skip_steps_cache.previous_modulated_input.device, dtype=skip_steps_cache.previous_modulated_input.dtype)
+                            cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = skip_steps_cache.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
+                        cache_should_calc[cache_slot_no] = skip_steps_cache.should_calc
+                        cache_accumulated_rel_l1_distance[cache_slot_no] = skip_steps_cache.accumulated_rel_l1_distance
+                        cache_teacache_skipped_steps[cache_slot_no] = skip_steps_cache.teacache_skipped_steps

                         cache_slot_no += 1
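
The audio pipeline denoises frames in windows and keeps one TeaCache slot per window: the slot's accumulated distance and skip counters are restored into the cache object before a window runs and saved back afterwards. A hedged, standalone sketch of that save/restore idea (variable names are illustrative, not the pipeline's exact ones):

```python
# Per-window cache slots: restore before the pass, save after it.
cache_size = 3
cache_accumulated = [0.0] * cache_size
cache_skipped = [0] * cache_size

class CacheState:
    accumulated_rel_l1_distance = 0.0
    skipped_steps = 0

state = CacheState()
for slot in range(cache_size):
    # restore this window's state
    state.accumulated_rel_l1_distance = cache_accumulated[slot]
    state.skipped_steps = cache_skipped[slot]
    # ... run the transformer for this window, possibly skipping it ...
    state.accumulated_rel_l1_distance += 0.01   # stand-in for the real update
    # save it back for the next denoising step
    cache_accumulated[slot] = state.accumulated_rel_l1_distance
    cache_skipped[slot] = state.skipped_steps

print(cache_accumulated)
```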

View File

@@ -9,6 +9,25 @@ def get_hunyuan_text_encoder_filename(text_encoder_quantization):
     return text_encoder_filename

 class family_handler():
+    @staticmethod
+    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
+        resolution = inputs["resolution"]
+        width, height = resolution.split("x")
+        pixels = int(width) * int(height)
+
+        if cache_type == "mag":
+            skip_steps_cache.update({
+                "magcache_thresh" : 0,
+                "magcache_K" : 2,
+            })
+            if pixels >= 1280* 720:
+                skip_steps_cache.def_mag_ratios = [1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456]
+            else:
+                skip_steps_cache.def_mag_ratios = [1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592]
+        else:
+            skip_steps_cache.coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+
     @staticmethod
     def query_model_def(base_model_type, model_def):
         extra_model_def = {}

@@ -25,7 +44,8 @@ class family_handler():
         extra_model_def["sliding_window"] = False
         extra_model_def["embedded_guidance"] = base_model_type in ["hunyuan", "hunyuan_i2v"]
         extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]
-        extra_model_def["skip_steps_cache"] = True
+        extra_model_def["tea_cache"] = True
+        extra_model_def["mag_cache"] = True
         return extra_model_def

     @staticmethod
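
The new `set_cache_parameters` hook only fills the cache holder with family-specific defaults; whether caching is enabled is decided elsewhere. A hedged sketch of the resolution branch it uses to pick the MagCache ratio table, with a `SimpleNamespace` subclass standing in for the `DynamicClass` holder (all names below are illustrative):

```python
from types import SimpleNamespace

class Holder(SimpleNamespace):
    def update(self, d):
        self.__dict__.update(d)

def pick_mag_defaults(resolution: str, cache: Holder) -> None:
    # mirrors the resolution test in set_cache_parameters above
    width, height = resolution.split("x")
    pixels = int(width) * int(height)
    cache.update({"magcache_thresh": 0, "magcache_K": 2})
    cache.is_720p_profile = pixels >= 1280 * 720   # selects which ratio table applies

cache = Holder(cache_type="mag")
pick_mag_defaults("832x480", cache)
print(cache.is_720p_profile)  # False -> the lower-resolution ratio table would be chosen
```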

View File

@@ -794,6 +794,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
             block.disable_deterministic()

     def compute_magcache_threshold(self, start_step, num_inference_steps = 0, speed_factor =0):
+        skips_step_cache = self.cache
         def nearest_interp(src_array, target_length):
             src_length = len(src_array)
             if target_length == 1:

@@ -801,11 +803,11 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
             scale = (src_length - 1) / (target_length - 1)
             mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
             return src_array[mapped_indices]
-        if len(self.def_mag_ratios) != num_inference_steps:
-            self.mag_ratios = nearest_interp(self.def_mag_ratios, num_inference_steps)
+        def_mag_ratios = np.array([1.0]+ skips_step_cache.def_mag_ratios)
+        if len(def_mag_ratios) != num_inference_steps:
+            skips_step_cache.mag_ratios = nearest_interp(def_mag_ratios, num_inference_steps)
         else:
-            self.mag_ratios = self.def_mag_ratios
+            skips_step_cache.mag_ratios = def_mag_ratios

         best_deltas = None
         best_threshold = 0.01

@@ -821,12 +823,12 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
                 if i<=start_step:
                     skip = False
                 else:
-                    cur_mag_ratio = self.mag_ratios[i] # conditional and unconditional in one list
+                    cur_mag_ratio = skips_step_cache.mag_ratios[i] # conditional and unconditional in one list
                     accumulated_ratio *= cur_mag_ratio # magnitude ratio between current step and the cached step
                     accumulated_steps += 1 # skip steps plus 1
                     cur_skip_err = np.abs(1-accumulated_ratio) # skip error of current steps
                     accumulated_err += cur_skip_err # accumulated error of multiple steps
-                    if accumulated_err<threshold and accumulated_steps<=self.magcache_K:
+                    if accumulated_err<threshold and accumulated_steps<=skips_step_cache.magcache_K:
                         skip = True
                     else:
                         skip = False

@@ -842,7 +844,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
             elif diff > best_diff:
                 break
             threshold += 0.01
-        self.magcache_thresh = best_threshold
+        skips_step_cache.magcache_thresh = best_threshold
         print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{num_inference_steps/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         return best_threshold

@@ -969,23 +971,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
         attn_mask = None
         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

-        if self.enable_cache:
+        should_calc = True
+        skip_steps_cache = self.cache
+        if skip_steps_cache is not None:
+            cache_type = skip_steps_cache.cache_type
             if x_id == 0:
-                self.should_calc = True
-                if self.enable_cache == "mag":
-                    if step_no > self.cache_start_step:
-                        cur_mag_ratio = self.mag_ratios[step_no]
-                        self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio
-                        cur_skip_err = np.abs(1-self.accumulated_ratio)
-                        self.accumulated_err += cur_skip_err
-                        self.accumulated_steps += 1
-                        if self.accumulated_err<=self.magcache_thresh and self.accumulated_steps<=self.magcache_K:
-                            self.should_calc = False
-                            self.cache_skipped_steps += 1
+                skip_steps_cache.should_calc = True
+                if cache_type == "mag":
+                    if step_no > skip_steps_cache.start_step:
+                        cur_mag_ratio = skip_steps_cache.mag_ratios[step_no]
+                        skip_steps_cache.accumulated_ratio = skip_steps_cache.accumulated_ratio*cur_mag_ratio
+                        cur_skip_err = np.abs(1-skip_steps_cache.accumulated_ratio)
+                        skip_steps_cache.accumulated_err += cur_skip_err
+                        skip_steps_cache.accumulated_steps += 1
+                        if skip_steps_cache.accumulated_err<=skip_steps_cache.magcache_thresh and skip_steps_cache.accumulated_steps<=skip_steps_cache.magcache_K:
+                            skip_steps_cache.should_calc = False
+                            skip_steps_cache.skipped_steps += 1
                         else:
-                            self.accumulated_ratio, self.accumulated_steps, self.accumulated_err = 1.0, 0, 0
+                            skip_steps_cache.accumulated_ratio, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_err = 1.0, 0, 0
                 else:
                     inp = img[0:1]
                     vec_ = vec[0:1]

@@ -994,26 +997,24 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
                     normed_inp = normed_inp.to(torch.bfloat16)
                     modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
                     del normed_inp, img_mod1_shift, img_mod1_scale
-                    if step_no <= self.cache_start_step or step_no == self.num_steps-1:
-                        self.accumulated_rel_l1_distance = 0
+                    if step_no <= skip_steps_cache.start_step or step_no == skip_steps_cache.num_steps-1:
+                        skip_steps_cache.accumulated_rel_l1_distance = 0
                     else:
-                        coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
-                        rescale_func = np.poly1d(coefficients)
-                        self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
-                        if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
-                            self.should_calc = False
-                            self.cache_skipped_steps += 1
+                        rescale_func = np.poly1d(skip_steps_cache.coefficients)
+                        skip_steps_cache.accumulated_rel_l1_distance += rescale_func(((modulated_inp-skip_steps_cache.previous_modulated_input).abs().mean() / skip_steps_cache.previous_modulated_input.abs().mean()).cpu().item())
+                        if skip_steps_cache.accumulated_rel_l1_distance < skip_steps_cache.rel_l1_thresh:
+                            skip_steps_cache.should_calc = False
+                            skip_steps_cache.skipped_steps += 1
                         else:
-                            self.accumulated_rel_l1_distance = 0
-                    self.previous_modulated_input = modulated_inp
-        else:
-            self.should_calc = True
+                            skip_steps_cache.accumulated_rel_l1_distance = 0
+                    skip_steps_cache.previous_modulated_input = modulated_inp
+            should_calc = skip_steps_cache.should_calc

-        if not self.should_calc:
-            img += self.previous_residual[x_id]
+        if not should_calc:
+            img += skip_steps_cache.previous_residual[x_id]
         else:
-            if self.enable_cache:
-                self.previous_residual[x_id] = None
+            if skip_steps_cache is not None:
+                skip_steps_cache.previous_residual[x_id] = None
             ori_img = img[0:1].clone()
         # --------------------- Pass through DiT blocks ------------------------
         for layer_num, block in enumerate(self.double_blocks):

@@ -1076,10 +1077,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
             single_block_args = None
         # img = x[:, :img_seq_len, ...]
-        if self.enable_cache:
+        if skip_steps_cache is not None:
             if len(img) > 1:
-                self.previous_residual[0] = torch.empty_like(img)
-                for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
+                skip_steps_cache.previous_residual[0] = torch.empty_like(img)
+                for i, (x, residual) in enumerate(zip(img, skip_steps_cache.previous_residual[0])):
                     if i < len(img) - 1:
                         residual[...] = torch.sub(x, ori_img)
                     else:

@@ -1087,8 +1088,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
                         torch.sub(x, ori_img, out=residual)
                     x = None
             else:
-                self.previous_residual[x_id] = ori_img
-                torch.sub(img, ori_img, out=self.previous_residual[x_id])
+                skip_steps_cache.previous_residual[x_id] = ori_img
+                torch.sub(img, ori_img, out=skip_steps_cache.previous_residual[x_id])

         if ref_length != None:
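
The hunks above contain the core MagCache decision: keep multiplying the calibrated magnitude ratios since the last full compute, accumulate the error of assuming the output barely changed, and skip the step while the error stays under the threshold and at most K consecutive steps have been skipped. A hedged, self-contained sketch of that rule (the threshold, K and ratios here are made-up stand-ins for the calibrated values):

```python
import numpy as np

def magcache_should_skip(step_no, mag_ratios, state, thresh=0.12, K=2, start_step=2):
    # never skip before the configured start step
    if step_no <= start_step:
        return False
    state["ratio"] *= mag_ratios[step_no]        # magnitude ratio vs. the cached step
    state["err"] += abs(1.0 - state["ratio"])    # accumulated skip error
    state["steps"] += 1                          # consecutive skipped steps
    if state["err"] <= thresh and state["steps"] <= K:
        return True                              # reuse the cached residual
    state["ratio"], state["steps"], state["err"] = 1.0, 0, 0.0
    return False                                 # do a full compute and reset

ratios = np.linspace(1.0, 0.95, 20)              # illustrative ratio table
state = {"ratio": 1.0, "steps": 0, "err": 0.0}
skipped = [s for s in range(20) if magcache_should_skip(s, ratios, state)]
print(skipped)
```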

View File

@@ -102,6 +102,8 @@ class model_factory():
             self.vae.tile_latent_min_height = VAE_tile_size[1]
             self.vae.tile_latent_min_width = VAE_tile_size[1]
+        self.vae.enable_slicing()

         # width, height = aspect_ratios["16:9"]

         if n_prompt is None or len(n_prompt) == 0:

View File

@@ -467,6 +467,16 @@ class QwenImageTransformer2DModel(nn.Module):
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

+    def preprocess_loras(self, model_type, sd):
+        new_sd = {}
+        for k,v in sd.items():
+            if k.startswith("transformer_blocks"):
+                k = "diffusion_model." + k
+            new_sd[k] = v
+        sd = new_sd
+        return sd
+
     def __init__(
         self,
         patch_size: int = 2,
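
The new `preprocess_loras` hook only remaps lora key names: bare `transformer_blocks.*` keys get the `diffusion_model.` prefix so they match the expected namespace. A tiny illustration of the remapping (the key strings are made-up examples):

```python
sd = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": "tensorA",
    "other_key": "tensorB",
}
new_sd = {("diffusion_model." + k if k.startswith("transformer_blocks") else k): v
          for k, v in sd.items()}
print(list(new_sd))
# ['diffusion_model.transformer_blocks.0.attn.to_q.lora_A.weight', 'other_key']
```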

View File

@@ -460,6 +460,7 @@ class WanAny2V:
                 sigmas=sampling_sigmas)
         else:
             raise NotImplementedError(f"Unsupported Scheduler {sample_solver}")
+        original_timesteps = timesteps

         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)

@@ -497,7 +498,7 @@ class WanAny2V:
         multitalk = model_type in ["multitalk", "vace_multitalk_14B"]
         recam = model_type in ["recam_1.3B"]
         ti2v = model_type in ["ti2v_2_2"]
+        start_step_no = 0
         ref_images_count = 0
         trim_frames = 0
         extended_overlapped_latents = None

@@ -638,6 +639,7 @@ class WanAny2V:
                         latent_keep_frames.append(all(keep_frames_parsed[i:i+4]))
             else:
                 timesteps = timesteps[injection_denoising_step:]
+                start_step_no = injection_denoising_step
                 if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps
                 if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
                 injection_denoising_step = 0

@@ -722,16 +724,17 @@ class WanAny2V:
         kwargs["freqs"] = freqs

         # Steps Skipping
-        cache_type = self.model.enable_cache
-        if cache_type != None:
+        skip_steps_cache = self.model.cache
+        if skip_steps_cache != None:
+            cache_type = skip_steps_cache.cache_type
             x_count = 3 if phantom or fantasy or multitalk else 2
-            self.model.previous_residual = [None] * x_count
+            skip_steps_cache.previous_residual = [None] * x_count
             if cache_type == "tea":
-                self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
+                self.model.compute_teacache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier)
             else:
-                self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
-                self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count
-                self.model.one_for_all = x_count > 2
+                self.model.compute_magcache_threshold(max(skip_steps_cache.start_step, start_step_no), original_timesteps, skip_steps_cache.multiplier)
+                skip_steps_cache.accumulated_err, skip_steps_cache.accumulated_steps, skip_steps_cache.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count
+                skip_steps_cache.one_for_all = x_count > 2

         if callback != None:
             callback(-1, None, True)

@@ -781,7 +784,7 @@ class WanAny2V:
                 timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device)
                 timestep[:source_latents.shape[2]] = 0
-            kwargs.update({"t": timestep, "current_step": i})
+            kwargs.update({"t": timestep, "current_step": start_step_no + i})
             kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None

             if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
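
The `original_timesteps` / `start_step_no` additions keep the cache calibration aligned with the full schedule when denoising starts part-way through (denoising_strength < 1): the threshold search runs over the original schedule, and the step index passed to the model is offset by the injection point. A hedged sketch of the offset idea with a stand-in schedule:

```python
original_timesteps = list(range(1000, 0, -100))   # stand-in schedule (10 steps)
injection_denoising_step = 4                       # denoising starts here
timesteps = original_timesteps[injection_denoising_step:]
start_step_no = injection_denoising_step

for i, t in enumerate(timesteps):
    current_step = start_step_no + i               # index into the original schedule
    # model(..., current_step=current_step) would read the matching mag ratio
print(start_step_no, len(timesteps))               # 4 6
```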

View File

@@ -1,6 +1,16 @@
 import torch

 class family_handler():
+    @staticmethod
+    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
+        if base_model_type == "sky_df_1.3B":
+            coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+        else:
+            coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+        skip_steps_cache.coefficients = coefficients
+
     @staticmethod
     def query_model_def(base_model_type, model_def):
         extra_model_def = {}

@@ -13,6 +23,7 @@ class family_handler():
         extra_model_def["frames_steps"] = 20
         extra_model_def["sliding_window"] = True
         extra_model_def["skip_layer_guidance"] = True
+        extra_model_def["tea_cache"] = True
         return extra_model_def

     @staticmethod

View File

@@ -313,21 +313,24 @@ class DTT2V:
         if callback != None:
             update_loras_slists(self.model, loras_slists, updated_num_steps)
             callback(-1, None, True, override_num_inference_steps = updated_num_steps)

-        if self.model.enable_cache == "tea":
-            x_count = 2 if self.do_classifier_free_guidance else 1
-            self.model.previous_residual = [None] * x_count
-            time_steps_comb = []
-            self.model.num_steps = updated_num_steps
-            for i, timestep_i in enumerate(step_matrix):
-                valid_interval_start, valid_interval_end = valid_interval[i]
-                timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
-                if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
-                    timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
-                time_steps_comb.append(timestep)
-            self.model.compute_teacache_threshold(self.model.cache_start_step, time_steps_comb, self.model.cache_multiplier)
-            del time_steps_comb
-        else:
-            self.model.enable_cache = None
+        skip_steps_cache = self.model.cache
+        if skip_steps_cache != None:
+            skip_steps_cache.num_steps = updated_num_steps
+            if skip_steps_cache.cache_type == "tea":
+                x_count = 2 if self.do_classifier_free_guidance else 1
+                skip_steps_cache.previous_residual = [None] * x_count
+                time_steps_comb = []
+                skip_steps_cache.steps = updated_num_steps
+                for i, timestep_i in enumerate(step_matrix):
+                    valid_interval_start, valid_interval_end = valid_interval[i]
+                    timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
+                    if overlap_noise > 0 and valid_interval_start < predix_video_latent_length:
+                        timestep[:, valid_interval_start:predix_video_latent_length] = overlap_noise
+                    time_steps_comb.append(timestep)
+                self.model.compute_teacache_threshold(skip_steps_cache.start_step, time_steps_comb, skip_steps_cache.multiplier)
+                del time_steps_comb
+            else:
+                self.model.cache = None

         from mmgp import offload
         freqs = get_rotary_pos_embed(latents.shape[2 :], enable_RIFLEx= False)
         kwrags = {

View File

@@ -1010,6 +1010,7 @@ class WanModel(ModelMixin, ConfigMixin):
         self._lock_dtype = dtype

     def compute_magcache_threshold(self, start_step, timesteps = None, speed_factor =0):
+        skips_step_cache = self.cache
         def nearest_interp(src_array, target_length):
             src_length = len(src_array)
             if target_length == 1: return np.array([src_array[-1]])

@@ -1017,13 +1018,14 @@ class WanModel(ModelMixin, ConfigMixin):
             mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
             return src_array[mapped_indices]
         num_inference_steps = len(timesteps)
-        if len(self.def_mag_ratios) != num_inference_steps*2:
-            mag_ratio_con = nearest_interp(self.def_mag_ratios[0::2], num_inference_steps)
-            mag_ratio_ucon = nearest_interp(self.def_mag_ratios[1::2], num_inference_steps)
+        def_mag_ratios = np.array([1.0]*2+ skips_step_cache.def_mag_ratios)
+        if len(def_mag_ratios) != num_inference_steps*2:
+            mag_ratio_con = nearest_interp(def_mag_ratios[0::2], num_inference_steps)
+            mag_ratio_ucon = nearest_interp(def_mag_ratios[1::2], num_inference_steps)
             interpolated_mag_ratios = np.concatenate([mag_ratio_con.reshape(-1, 1), mag_ratio_ucon.reshape(-1, 1)], axis=1).reshape(-1)
-            self.mag_ratios = interpolated_mag_ratios
+            skips_step_cache.mag_ratios = interpolated_mag_ratios
         else:
-            self.mag_ratios = self.def_mag_ratios
+            skips_step_cache.mag_ratios = def_mag_ratios

         best_deltas = None

@@ -1044,12 +1046,12 @@ class WanModel(ModelMixin, ConfigMixin):
                 else:
                     x_should_calc = []
                     for cur_x_id in range(x_id_max):
-                        cur_mag_ratio = self.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list
+                        cur_mag_ratio = skips_step_cache.mag_ratios[i * 2 + cur_x_id] # conditional and unconditional in one list
                         accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
                         accumulated_steps[cur_x_id] += 1 # skip steps plus 1
                         cur_skip_err = np.abs(1-accumulated_ratio[cur_x_id]) # skip error of current steps
                         accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps
-                        if accumulated_err[cur_x_id]<threshold and accumulated_steps[cur_x_id]<=self.magcache_K:
+                        if accumulated_err[cur_x_id]<threshold and accumulated_steps[cur_x_id]<=skips_step_cache.magcache_K:
                             skip = True
                         else:
                             skip = False

@@ -1066,13 +1068,14 @@ class WanModel(ModelMixin, ConfigMixin):
             elif diff > best_diff:
                 break
             threshold += 0.01
-        self.magcache_thresh = best_threshold
+        skips_step_cache.magcache_thresh = best_threshold
         print(f"Mag Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         return best_threshold

     def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
+        skips_step_cache = self.cache
         modulation_dtype = self.time_projection[1].weight.dtype
-        rescale_func = np.poly1d(self.coefficients)
+        rescale_func = np.poly1d(skips_step_cache.coefficients)
         e_list = []
         for t in timesteps:
             t = torch.stack([t])

@@ -1112,7 +1115,7 @@ class WanModel(ModelMixin, ConfigMixin):
             elif diff > best_diff:
                 break
             threshold += 0.01
-        self.rel_l1_thresh = best_threshold
+        skips_step_cache.rel_l1_thresh = best_threshold
         print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         # print(f"deltas:{best_deltas}")
         return best_threshold
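
Wan stores its calibrated magnitude ratios interleaved: even indices belong to the conditional stream, odd indices to the unconditional one. When the requested step count differs from the calibration length, each stream is resampled separately with nearest-neighbour interpolation and re-interleaved, as sketched below (the ratio values are illustrative, not the calibrated tables):

```python
import numpy as np

def nearest_interp(src_array, target_length):
    src_length = len(src_array)
    if target_length == 1:
        return np.array([src_array[-1]])
    scale = (src_length - 1) / (target_length - 1)
    mapped = np.round(np.arange(target_length) * scale).astype(int)
    return src_array[mapped]

# two leading 1.0s for the uncalibrated first step, then cond/uncond pairs
def_mag_ratios = np.array([1.0] * 2 + [0.99, 0.98, 0.97, 0.96, 0.95, 0.94])
num_steps = 3
cond = nearest_interp(def_mag_ratios[0::2], num_steps)     # conditional stream
uncond = nearest_interp(def_mag_ratios[1::2], num_steps)   # unconditional stream
mag_ratios = np.stack([cond, uncond], axis=1).reshape(-1)  # re-interleave
print(mag_ratios)   # 2 * num_steps interleaved values
```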
@@ -1282,72 +1285,73 @@ class WanModel(ModelMixin, ConfigMixin):
         del c
         should_calc = True
         x_should_calc = None
-        if self.enable_cache != None:
-            if self.enable_cache == "mag":
-                if current_step <= self.cache_start_step:
+        skips_steps_cache = self.cache
+        if skips_steps_cache != None:
+            if skips_steps_cache.cache_type == "mag":
+                if current_step <= skips_steps_cache.start_step:
                     should_calc = True
-                elif self.one_for_all and x_id != 0: # not joint pass, not main pas, one for all
+                elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all
                     assert len(x_list) == 1
-                    should_calc = self.should_calc
+                    should_calc = skips_steps_cache.should_calc
                 else:
                     x_should_calc = []
-                    for i in range(1 if self.one_for_all else len(x_list)):
+                    for i in range(1 if skips_steps_cache.one_for_all else len(x_list)):
                         cur_x_id = i if joint_pass else x_id
-                        cur_mag_ratio = self.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list
-                        self.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
-                        self.accumulated_steps[cur_x_id] += 1 # skip steps plus 1
-                        cur_skip_err = np.abs(1-self.accumulated_ratio[cur_x_id]) # skip error of current steps
-                        self.accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps
-                        if self.accumulated_err[cur_x_id]<self.magcache_thresh and self.accumulated_steps[cur_x_id]<=self.magcache_K:
+                        cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list
+                        skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
+                        skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1
+                        cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps
+                        skips_steps_cache.accumulated_err[cur_x_id] += cur_skip_err # accumulated error of multiple steps
+                        if skips_steps_cache.accumulated_err[cur_x_id]<skips_steps_cache.magcache_thresh and skips_steps_cache.accumulated_steps[cur_x_id]<=skips_steps_cache.magcache_K:
                             skip_forward = True
-                            if i == 0 and x_id == 0: self.cache_skipped_steps += 1
-                            # print(f"skip: step={current_step} for x_id={cur_x_id}, accum error {self.accumulated_err[cur_x_id]}")
+                            if i == 0 and x_id == 0: skips_steps_cache.skipped_steps += 1
+                            # print(f"skip: step={current_step} for x_id={cur_x_id}, accum error {skips_step_cache.accumulated_err[cur_x_id]}")
                         else:
                             skip_forward = False
-                            self.accumulated_err[cur_x_id], self.accumulated_steps[cur_x_id], self.accumulated_ratio[cur_x_id] = 0, 0, 1.0
+                            skips_steps_cache.accumulated_err[cur_x_id], skips_steps_cache.accumulated_steps[cur_x_id], skips_steps_cache.accumulated_ratio[cur_x_id] = 0, 0, 1.0
                         x_should_calc.append(not skip_forward)
-                    if self.one_for_all:
-                        should_calc = self.should_calc = x_should_calc[0]
+                    if skips_steps_cache.one_for_all:
+                        should_calc = skips_steps_cache.should_calc = x_should_calc[0]
                         x_should_calc = None
             else:
                 if x_id != 0:
-                    should_calc = self.should_calc
+                    should_calc = skips_steps_cache.should_calc
                 else:
-                    if current_step <= self.cache_start_step or current_step == self.num_steps-1:
+                    if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1:
                         should_calc = True
-                        self.accumulated_rel_l1_distance = 0
+                        skips_steps_cache.accumulated_rel_l1_distance = 0
                     else:
-                        rescale_func = np.poly1d(self.coefficients)
-                        delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
-                        self.accumulated_rel_l1_distance += delta
-                        if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                        rescale_func = np.poly1d(skips_steps_cache.coefficients)
+                        delta = abs(rescale_func(((e-skips_steps_cache.previous_modulated_input).abs().mean() / skips_steps_cache.previous_modulated_input.abs().mean()).cpu().item()))
+                        skips_steps_cache.accumulated_rel_l1_distance += delta
+                        if skips_steps_cache.accumulated_rel_l1_distance < skips_steps_cache.rel_l1_thresh:
                             should_calc = False
-                            self.cache_skipped_steps += 1
-                            # print(f"Teacache Skipped Step no {current_step} ({self.cache_skipped_steps}/{current_step}), delta={delta}" )
+                            skips_steps_cache.skipped_steps += 1
+                            # print(f"Teacache Skipped Step no {current_step} ({skips_step_cache.cache_skipped_steps}/{current_step}), delta={delta}" )
                         else:
                             should_calc = True
-                            self.accumulated_rel_l1_distance = 0
-                    self.previous_modulated_input = e
-                    self.should_calc = should_calc
+                            skips_steps_cache.accumulated_rel_l1_distance = 0
+                    skips_steps_cache.previous_modulated_input = e
+                    skips_steps_cache.should_calc = should_calc

         if x_should_calc == None: x_should_calc = [should_calc] * len(x_list)
         if joint_pass:
             for i, x in enumerate(x_list):
-                if not x_should_calc[i]: x += self.previous_residual[i]
+                if not x_should_calc[i]: x += skips_steps_cache.previous_residual[i]
         elif not x_should_calc[0]:
             x = x_list[0]
-            x += self.previous_residual[x_id]
+            x += skips_steps_cache.previous_residual[x_id]
             x = None

-        if self.enable_cache != None:
-            if self.previous_residual == None: self.previous_residual = [ None ] * len(self.previous_residual)
+        if skips_steps_cache != None:
+            if skips_steps_cache.previous_residual == None: skips_steps_cache.previous_residual = [ None ] * len(x_list)
             if joint_pass:
                 for i, should_calc in enumerate(x_should_calc):
-                    if should_calc: self.previous_residual[i] = None
+                    if should_calc: skips_steps_cache.previous_residual[i] = None
             elif x_should_calc[0]:
-                self.previous_residual[x_id] = None
+                skips_steps_cache.previous_residual[x_id] = None
         ori_hidden_states = [ None ] * len(x_list)
         if all(x_should_calc):
             ori_hidden_states[0] = x_list[0].clone()

@@ -1379,27 +1383,27 @@ class WanModel(ModelMixin, ConfigMixin):
             del x
         context = hints = audio_embedding = None

-        if self.enable_cache != None:
+        if skips_steps_cache != None:
             if joint_pass:
                 if all(x_should_calc):
                     for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
                         if i == 0 or is_source and i != last_x_idx :
-                            self.previous_residual[i] = torch.sub(x, ori)
+                            skips_steps_cache.previous_residual[i] = torch.sub(x, ori)
                         else:
-                            self.previous_residual[i] = ori
-                            torch.sub(x, ori, out=self.previous_residual[i])
+                            skips_steps_cache.previous_residual[i] = ori
+                            torch.sub(x, ori, out=skips_steps_cache.previous_residual[i])
                         ori_hidden_states[i] = None
                 else:
                     for i, (x, ori, is_source, should_calc) in enumerate(zip(x_list, ori_hidden_states, is_source_x, x_should_calc)) :
                         if should_calc:
-                            self.previous_residual[i] = ori
-                            torch.sub(x, ori, out=self.previous_residual[i])
+                            skips_steps_cache.previous_residual[i] = ori
+                            torch.sub(x, ori, out=skips_steps_cache.previous_residual[i])
                             ori_hidden_states[i] = None
                 x , ori = None, None
             elif x_should_calc[0]:
                 residual = ori_hidden_states[0] # just to have a readable code
                 torch.sub(x_list[0], ori_hidden_states[0], out=residual)
-                self.previous_residual[x_id] = residual
+                skips_steps_cache.previous_residual[x_id] = residual
                 residual, ori_hidden_states = None, None

         for i, x in enumerate(x_list):
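
The residual-reuse contract in this forward pass: a skipped stream re-applies the residual cached from its last full pass, while a computed stream refreshes the cache with `x - ori`. A hedged, self-contained sketch of that contract (the `run_blocks` callable stands in for the DiT block stack):

```python
import torch

def apply_or_store(x, previous_residual, should_calc, run_blocks):
    if not should_calc:
        return x + previous_residual, previous_residual   # cheap skipped step
    ori = x.clone()
    x = run_blocks(x)                                      # full transformer pass
    return x, x - ori                                      # refresh the cached residual

x = torch.zeros(1, 4)
residual = torch.ones(1, 4) * 0.5
x, residual = apply_or_store(x, residual, should_calc=False, run_blocks=lambda t: t)
print(x)          # tensor([[0.5, 0.5, 0.5, 0.5]])
```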

View File

@@ -1,9 +1,59 @@
 import torch
+import numpy as np

 def test_class_i2v(base_model_type):
     return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", ] #"hunyuan_i2v",

+def test_class_1_3B(base_model_type):
+    return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
+
 class family_handler():
+    @staticmethod
+    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
+        i2v = test_class_i2v(base_model_type)
+        resolution = inputs["resolution"]
+        width, height = resolution.split("x")
+        pixels = int(width) * int(height)
+
+        if cache_type == "mag":
+            skip_steps_cache.update({
+                "magcache_thresh" : 0,
+                "magcache_K" : 2,
+            })
+            if base_model_type in ["t2v"] and "URLs2" in model_def:
+                def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
+            elif base_model_type in ["i2v_2_2"]:
+                def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
+            elif base_model_type in ["ti2v_2_2"]:
+                if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
+                    def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
+                else: # i2v
+                    def_mag_ratios = [0.99512, 0.99559, 0.99559, 0.99561, 0.99595, 0.99577, 0.99512, 0.99512, 0.99546, 0.99534, 0.99543, 0.99531, 0.99496, 0.99491, 0.99504, 0.99499, 0.99444, 0.99449, 0.99481, 0.99481, 0.99435, 0.99435, 0.9943, 0.99431, 0.99411, 0.99406, 0.99373, 0.99376, 0.99413, 0.99405, 0.99363, 0.99359, 0.99335, 0.99331, 0.99244, 0.99243, 0.99229, 0.99229, 0.99239, 0.99236, 0.99163, 0.9916, 0.99149, 0.99151, 0.99191, 0.99192, 0.9898, 0.98981, 0.9899, 0.98987, 0.98849, 0.98849, 0.98846, 0.98846, 0.98861, 0.98861, 0.9874, 0.98738, 0.98588, 0.98589, 0.98539, 0.98534, 0.98444, 0.98439, 0.9831, 0.98309, 0.98119, 0.98118, 0.98001, 0.98, 0.97862, 0.97859, 0.97555, 0.97558, 0.97392, 0.97388, 0.97152, 0.97145, 0.96871, 0.9687, 0.96435, 0.96434, 0.96129, 0.96127, 0.95639, 0.95638, 0.95176, 0.95175, 0.94446, 0.94452, 0.93972, 0.93974, 0.93575, 0.9359, 0.93537, 0.93552, 0.96655, 0.96616]
+            elif test_class_1_3B(base_model_type): #text 1.3B
+                def_mag_ratios = [1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]#**(0.5)# In our papaer, we utilize the sqrt to smooth the ratio, which has little impact on the performance and can be deleted.
+            elif i2v:
+                if pixels >= 1280*720:
+                    def_mag_ratios = [0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]
+                else:
+                    def_mag_ratios = [0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]
+            else: # text 14B
+                def_mag_ratios = [1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]
+            skip_steps_cache.def_mag_ratios = def_mag_ratios
+        else:
+            if i2v:
+                if pixels >= 1280*720:
+                    coefficients= [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+                else:
+                    coefficients= [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+            else:
+                if test_class_1_3B(base_model_type):
+                    coefficients= [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+                else:
+                    coefficients= [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+            skip_steps_cache.coefficients = coefficients
+
     @staticmethod
     def get_wan_text_encoder_filename(text_encoder_quantization):
         text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors"

@@ -27,7 +77,8 @@ class family_handler():
         extra_model_def = {}
         if "URLs2" in model_def:
             extra_model_def["no_steps_skipping"] = True
-        extra_model_def["i2v_class"] = test_class_i2v(base_model_type)
+        i2v = test_class_i2v(base_model_type)
+        extra_model_def["i2v_class"] = i2v

         vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
         extra_model_def["vace_class"] = vace_class

@@ -55,7 +106,9 @@ class family_handler():
             "cfg_zero" : True,
             "cfg_star" : True,
             "adaptive_projected_guidance" : True,
-            "skip_steps_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or "URLs2" in model_def),
+            "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or "URLs2" in model_def),
+            "mag_cache" : True,
         })
         return extra_model_def

View File

@@ -44,7 +44,7 @@ def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, me
     slists_dict["phase2"] = phase2 = [1.] * nb_loras

     if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0:
-        list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)
+        list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras]
         for i, mult in enumerate(list_mult_choices_list):
             current_phase = phase1
             if isinstance(mult, str):

wgp.py (146 changed lines)
View File

@@ -177,7 +177,7 @@ def process_prompt_and_add_tasks(state, model_choice):
         return get_queue_table(queue)
     model_def = get_model_def(model_type)
     image_outputs = inputs["image_mode"] == 1
-    any_steps_skipping = model_def.get("skip_steps_cache", False)
+    any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False)
     model_type = get_base_model_type(model_type)

     inputs["model_filename"] = model_filename

@@ -302,16 +302,10 @@ def process_prompt_and_add_tasks(state, model_choice):
             return

     if not any_steps_skipping: skip_steps_cache_type = ""
-    if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 0:
-        gr.Info("Steps skipping is not yet supported if Switch Threshold is not null")
-        return
     if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
         gr.Info("The minimum number of steps should be 20")
         return

     if skip_steps_cache_type == "mag":
-        if model_type in ["sky_df_1.3B", "sky_df_14B"]:
-            gr.Info("Mag Cache is not supported with Diffusion Forcing")
-            return
         if num_inference_steps > 50:
             gr.Info("Mag Cache maximum number of steps is 50")
             return
@@ -3844,6 +3838,36 @@ def get_transformer_loras(model_type):
     transformer_loras_multipliers = transformer_loras_multipliers[:len(transformer_loras_filenames)]
     return transformer_loras_filenames, transformer_loras_multipliers

+class DynamicClass:
+    def __init__(self, **kwargs):
+        self._data = {}
+        # Preassign default properties from kwargs
+        for key, value in kwargs.items():
+            self._data[key] = value
+
+    def __getattr__(self, name):
+        if name in self._data:
+            return self._data[name]
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def __setattr__(self, name, value):
+        if name.startswith('_'):
+            super().__setattr__(name, value)
+        else:
+            if not hasattr(self, '_data'):
+                super().__setattr__('_data', {})
+            self._data[name] = value
+
+    def assign(self, **kwargs):
+        """Assign multiple properties at once"""
+        for key, value in kwargs.items():
+            self._data[key] = value
+        return self  # For method chaining
+
+    def update(self, dict):
+        """Alias for assign() - more dict-like"""
+        return self.assign(**dict)
+
 def generate_video(
     task,
     send_cmd,
@ -3999,7 +4023,6 @@ def generate_video(
width, height = resolution.split("x") width, height = resolution.split("x")
width, height = int(width), int(height) width, height = int(width), int(height)
resolution_reformated = str(height) + "*" + str(width)
default_image_size = (height, width) default_image_size = (height, width)
if slg_switch == 0: if slg_switch == 0:
@ -4137,50 +4160,26 @@ def generate_video(
update_task_thumbnails(task, locals()) update_task_thumbnails(task, locals())
send_cmd("output") send_cmd("output")
joint_pass = boost ==1 #and profile != 1 and profile != 3 joint_pass = boost ==1 #and profile != 1 and profile != 3
trans.enable_cache = None if len(skip_steps_cache_type) == 0 else skip_steps_cache_type
if trans2 is not None: skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type)
trans2.enable_cache = None
if trans.enable_cache != None: if skip_steps_cache != None:
trans.cache_multiplier = skip_steps_multiplier skip_steps_cache.update({
trans.cache_start_step = int(skip_steps_start_step_perc*num_inference_steps/100) "multiplier" : skip_steps_multiplier,
"start_step": int(skip_steps_start_step_perc*num_inference_steps/100)
if trans.enable_cache == "mag": })
trans.magcache_thresh = 0 model_handler.set_cache_parameters(skip_steps_cache_type, base_model_type, model_def, locals(), skip_steps_cache)
trans.magcache_K = 2 if skip_steps_cache_type == "mag":
def_mag_ratios = model_def.get("magcache_ratios", None) if model_def != None else None def_mag_ratios = model_def.get("magcache_ratios", None) if model_def != None else None
if def_mag_ratios != None: if def_mag_ratios is not None: skip_steps_cache.def_mag_ratios = def_mag_ratios
trans.def_mag_ratios = def_mag_ratios elif skip_steps_cache_type == "tea":
elif get_model_family(model_type) == "wan": def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None
if i2v: if def_tea_coefficients is not None: skip_steps_cache.coefficients = def_tea_coefficients
trans.def_mag_ratios = np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939])#**(0.5)# In our papaer, we utilize the sqrt to smooth the ratio, which has little impact on the performance and can be deleted.
else:
trans.def_mag_ratios = np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189])
else: else:
if width * height >= 1280* 720: raise Exception(f"unknown cache type {skip_steps_cache_type}")
trans.def_mag_ratios = np.array([1.0]+[1.0754, 1.27807, 1.11596, 1.09504, 1.05188, 1.00844, 1.05779, 1.00657, 1.04142, 1.03101, 1.00679, 1.02556, 1.00908, 1.06949, 1.05438, 1.02214, 1.02321, 1.03019, 1.00779, 1.03381, 1.01886, 1.01161, 1.02968, 1.00544, 1.02822, 1.00689, 1.02119, 1.0105, 1.01044, 1.01572, 1.02972, 1.0094, 1.02368, 1.0226, 0.98965, 1.01588, 1.02146, 1.0018, 1.01687, 0.99436, 1.00283, 1.01139, 0.97122, 0.98251, 0.94513, 0.97656, 0.90943, 0.85703, 0.75456]) trans.cache = skip_steps_cache
else: if trans2 is not None: trans2.cache = skip_steps_cache
trans.def_mag_ratios = np.array([1.0]+[1.06971, 1.29073, 1.11245, 1.09596, 1.05233, 1.01415, 1.05672, 1.00848, 1.03632, 1.02974, 1.00984, 1.03028, 1.00681, 1.06614, 1.05022, 1.02592, 1.01776, 1.02985, 1.00726, 1.03727, 1.01502, 1.00992, 1.03371, 0.9976, 1.02742, 1.0093, 1.01869, 1.00815, 1.01461, 1.01152, 1.03082, 1.0061, 1.02162, 1.01999, 0.99063, 1.01186, 1.0217, 0.99947, 1.01711, 0.9904, 1.00258, 1.00878, 0.97039, 0.97686, 0.94315, 0.97728, 0.91154, 0.86139, 0.76592])
elif trans.enable_cache == "tea":
trans.rel_l1_thresh = 0
model_def = get_model_def(model_type)
def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None
if def_tea_coefficients != None:
trans.coefficients = def_tea_coefficients
elif get_model_family(model_type) == "wan":
if i2v:
if '720p' in model_filename:
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else:
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
else:
if '1.3B' in model_filename:
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
elif '14B' in model_filename:
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
else:
raise gr.Error("Teacache not supported for this model")
output_new_audio_data = None output_new_audio_data = None
output_new_audio_filepath = None output_new_audio_filepath = None
original_audio_guide = audio_guide original_audio_guide = audio_guide
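The per-family Mag Cache ratio tables and Tea Cache polynomial coefficients that used to be hard-coded here now come from each family handler through set_cache_parameters(...). A rough sketch of what such a handler could do for a Wan-style model, reusing values from the removed block above; only the call signature appears in this diff, the body below is hypothetical:
import numpy as np

def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
    # Hypothetical handler body (the signature matches the call added above).
    i2v = base_model_type in ("i2v", "i2v_720p", "fantasy")
    if cache_type == "mag":
        skip_steps_cache.update({"magcache_thresh": 0, "magcache_K": 2})
        ratios = [1.0124, 1.02213] if i2v else [1.02504, 1.03017]        # truncated tables
        skip_steps_cache.def_mag_ratios = np.array([1.0] * 2 + ratios)
    else:  # "tea"
        skip_steps_cache.rel_l1_thresh = 0
        skip_steps_cache.coefficients = (
            [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] if i2v
            else [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404])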
@ -4502,12 +4501,13 @@ def generate_video(
progress_args = [0, merge_status_context(status, "Encoding Prompt")] progress_args = [0, merge_status_context(status, "Encoding Prompt")]
send_cmd("progress", progress_args) send_cmd("progress", progress_args)
if trans.enable_cache != None: if skip_steps_cache != None:
trans.num_steps = num_inference_steps skip_steps_cache.update({
trans.cache_skipped_steps = 0 "num_steps" : num_inference_steps,
trans.previous_residual = None "skipped_steps" : 0,
trans.previous_modulated_input = None "previous_residual": None,
"previous_modulated_input": None,
})
# samples = torch.empty( (1,2)) #for testing # samples = torch.empty( (1,2)) #for testing
# if False: # if False:
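For context, the fields reset just above (previous_residual, previous_modulated_input, skipped_steps, num_steps) are what the transformer consults on every denoising step. A very simplified sketch of the residual-reuse idea shared by Tea Cache and Mag Cache, assuming the accumulator and threshold fields have been initialised elsewhere; this is not the actual WanGP forward path:
def maybe_skip_step(cache, step_no, error_estimate, hidden_states):
    # Returns a cheap approximation of the step output when the accumulated
    # error stays under the threshold, otherwise None to run the full step.
    if step_no < cache.start_step or cache.previous_residual is None:
        return None
    cache.accumulated_rel_l1_distance += error_estimate
    if cache.accumulated_rel_l1_distance < cache.rel_l1_thresh:
        cache.skipped_steps += 1
        return hidden_states + cache.previous_residual
    cache.accumulated_rel_l1_distance = 0
    return None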
@ -4580,8 +4580,12 @@ def generate_video(
cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks) cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks)
remove_temp_filenames(temp_filenames_list) remove_temp_filenames(temp_filenames_list)
offloadobj.unload_all() offloadobj.unload_all()
trans.cache = None
offload.unload_loras_from_model(trans) offload.unload_loras_from_model(trans)
if trans is not None: offload.unload_loras_from_model(trans) if trans2 is not None:
trans2.cache = None
offload.unload_loras_from_model(trans2)
skip_steps_cache = None
# if compile: # if compile:
# cache_size = torch._dynamo.config.cache_size_limit # cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset() # torch.compiler.reset()
@ -4608,12 +4612,11 @@ def generate_video(
send_cmd("error", new_error) send_cmd("error", new_error)
clear_status(state) clear_status(state)
return return
finally:
trans.previous_residual = None
trans.previous_modulated_input = None
if trans.enable_cache != None : if skip_steps_cache != None :
print(f"Skipped Steps:{trans.cache_skipped_steps}/{trans.num_steps}" ) skip_steps_cache.previous_residual = None
skip_steps_cache.previous_modulated_input = None
print(f"Skipped Steps:{skip_steps_cache.skipped_steps}/{skip_steps_cache.num_steps}" )
if samples != None: if samples != None:
if isinstance(samples, dict): if isinstance(samples, dict):
@ -4797,8 +4800,10 @@ def generate_video(
seed = set_seed(-1) seed = set_seed(-1)
clear_status(state) clear_status(state)
trans.cache = None
offload.unload_loras_from_model(trans) offload.unload_loras_from_model(trans)
if not trans2 is None: if not trans2 is None:
trans2.cache = None
offload.unload_loras_from_model(trans2) offload.unload_loras_from_model(trans2)
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
@ -5523,7 +5528,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if not model_def.get("embedded_guidance", False) or model_def.get("no_guidance", False): if not model_def.get("embedded_guidance", False) or model_def.get("no_guidance", False):
pop += ["embedded_guidance_scale"] pop += ["embedded_guidance_scale"]
if not model_def.get("skip_steps_cache", False) : if not (model_def.get("tea_cache", False) or model_def.get("mag_cache", False)) :
pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"] pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"]
if model_def.get("no_guidance", False) : if model_def.get("no_guidance", False) :
@ -6496,7 +6501,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
ltxv = "ltxv" in model_filename ltxv = "ltxv" in model_filename
lock_inference_steps = model_def.get("lock_inference_steps", False) lock_inference_steps = model_def.get("lock_inference_steps", False)
model_reference_image = model_def.get("reference_image", False) model_reference_image = model_def.get("reference_image", False)
any_steps_skipping = model_def.get("skip_steps_cache", False) any_tea_cache = model_def.get("tea_cache", False)
any_mag_cache = model_def.get("mag_cache", False)
recammaster = base_model_type in ["recam_1.3B"] recammaster = base_model_type in ["recam_1.3B"]
vace = test_vace_module(base_model_type) vace = test_vace_module(base_model_type)
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
@ -6988,18 +6994,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Activated Loras" label="Activated Loras"
) )
loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str) loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str)
with gr.Tab("Steps Skipping", visible = any_steps_skipping) as speed_tab: with gr.Tab("Steps Skipping", visible = any_tea_cache or any_mag_cache) as speed_tab:
with gr.Column(): with gr.Column():
gr.Markdown("<B>Tea Cache and Mag Cache speed up video generation by intelligently skipping some denoising steps; the more steps are skipped, the lower the video quality.</B>") gr.Markdown("<B>Tea Cache and Mag Cache speed up video generation by intelligently skipping some denoising steps; the more steps are skipped, the lower the video quality.</B>")
gr.Markdown("<B>Steps Skipping also consumes VRAM. It is recommended to keep at least the first 10% of steps unskipped.</B>") gr.Markdown("<B>Steps Skipping also consumes VRAM. It is recommended to keep at least the first 10% of steps unskipped.</B>")
steps_skipping_choices = [("None", "")]
if any_tea_cache: steps_skipping_choices += [("Tea Cache", "tea")]
if any_mag_cache: steps_skipping_choices += [("Mag Cache", "mag")]
skip_steps_cache_type = gr.Dropdown( skip_steps_cache_type = gr.Dropdown(
choices=[ choices= steps_skipping_choices,
("None", ""), value="" if not (any_tea_cache or any_mag_cache) else ui_defaults.get("skip_steps_cache_type",""),
("Tea Cache", "tea"),
("Mag Cache", "mag"),
],
value=ui_defaults.get("skip_steps_cache_type",""),
visible=True, visible=True,
label="Skip Steps Cache Type" label="Skip Steps Cache Type"
) )
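For example, for a Wan 2.2 i2v finetune (tea_cache False, mag_cache True per the handler change at the top of this diff), the conditional build above would yield, illustratively:
steps_skipping_choices = [("None", ""), ("Mag Cache", "mag")]   # the "tea" entry is omitted for this model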