Added CFG Zero *

This commit is contained in:
DeepBeepMeep 2025-03-26 00:47:10 +01:00
parent 91e35a806c
commit 826d5ac84f
8 changed files with 186 additions and 143 deletions

View File

@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 19 2022: 👋 Wan2.1GP v3.2:
- Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team:**\
Dont hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
- Added back support for Pytorch compilation with Loras. It seems it had been broken for some time
- Added possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings)
You will need one more *pip install -r requirements.txt*
* Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
You will need one more *pip install -r requirements.txt*
* Mar 18 2022: 👋 Wan2.1GP v3.0:

View File

@ -23,7 +23,7 @@ import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.3.3"
target_mmgp_version = "3.3.4"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@ -300,6 +300,7 @@ if not Path(server_config_filename).is_file():
"metadata_type": "metadata",
"default_ui": "t2v",
"boost" : 1,
"clear_file_list" : 0,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM }
@ -382,7 +383,6 @@ if len(args.vae_config) > 0:
reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
metadata = server_config.get("metadata_type", "metadata")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
use_image2video = default_ui != "t2v"
if args.t2v:
@ -741,7 +741,8 @@ def apply_changes( state,
vae_config_choice,
metadata_choice,
default_ui_choice ="t2v",
boost_choice = 1
boost_choice = 1,
clear_file_list = 0,
):
if args.lock_config:
return
@ -760,6 +761,7 @@ def apply_changes( state,
"metadata_choice": metadata_choice,
"default_ui" : default_ui_choice,
"boost" : boost_choice,
"clear_file_list" : clear_file_list
}
if Path(server_config_filename).is_file():
@ -792,7 +794,7 @@ def apply_changes( state,
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
boost = server_config["boost"]
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
pass
else:
reload_needed = True
@ -849,13 +851,17 @@ def refresh_gallery(state, txt):
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
choice = 0
if "in_progress" in state:
choice = state.get("selected",0)
prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label)
def finalize_gallery(state):
@ -863,6 +869,8 @@ def finalize_gallery(state):
if "in_progress" in state:
del state["in_progress"]
choice = state.get("selected",0)
# file_list = state.get("file_list", [])
state["extra_orders"] = 0
time.sleep(0.2)
@ -930,6 +938,7 @@ def generate_video(
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
@ -938,7 +947,9 @@ def generate_video(
slg_switch,
slg_layers,
slg_start,
slg_end,
slg_end,
cfg_star_switch,
cfg_zero_step,
state,
image2video,
progress=gr.Progress() #track_tqdm= True
@ -1031,6 +1042,8 @@ def generate_video(
if len(prompts) ==0:
return
if image2video:
if image_prompt_type == 0:
image_to_end = None
if image_to_continue is not None:
if isinstance(image_to_continue, list):
image_to_continue = [ tup[0] for tup in image_to_continue ]
@ -1135,7 +1148,6 @@ def generate_video(
if "abort" in state:
del state["abort"]
state["in_progress"] = True
state["selected"] = 0
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
@ -1172,8 +1184,21 @@ def generate_video(
if seed == None or seed <0:
seed = random.randint(0, 999999999)
file_list = []
clear_file_list = server_config.get("clear_file_list", 0)
file_list = state.get("file_list", [])
if clear_file_list > 0:
file_list_current_size = len(file_list)
keep_file_from = max(file_list_current_size - clear_file_list, 0)
files_removed = keep_file_from
choice = state.get("selected",0)
choice = max(choice- files_removed, 0)
file_list = file_list[ keep_file_from: ]
else:
file_list = []
choice = 0
state["selected"] = choice
state["file_list"] = file_list
global save_path
os.makedirs(save_path, exist_ok=True)
video_no = 0
@ -1240,6 +1265,8 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
)
else:
@ -1260,6 +1287,8 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
)
except Exception as e:
gen_in_progress = False
@ -1326,7 +1355,7 @@ def generate_video(
value_range=(-1, 1))
configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
@ -1642,7 +1671,7 @@ def switch_advanced(state, new_advanced, lset_name):
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
loras = state["loras"]
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@ -1666,7 +1695,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
"slg_switch": slg_switch,
"slg_layers": slg_layers,
"slg_start_perc": slg_start_perc,
"slg_end_perc": slg_end_perc
"slg_end_perc": slg_end_perc,
"cfg_star_switch": cfg_star_switch,
"cfg_zero_step": cfg_zero_step
}
if i2v:
@ -1678,13 +1709,13 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
return ui_settings
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
if state.get("validate_success",0) != 1:
return
ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
defaults_filename = get_settings_file_name(use_image2video)
@ -1955,7 +1986,7 @@ def generate_video_tab(image2video=False):
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer guidance,should improve video quality</B>")
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@ -1979,6 +2010,23 @@ def generate_video_tab(image2video=False):
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default")
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
@ -1997,7 +2045,7 @@ def generate_video_tab(image2video=False):
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
slg_start_perc, slg_end_perc ], outputs = [])
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@ -2035,6 +2083,7 @@ def generate_video_tab(image2video=False):
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_prompt_type_radio,
image_to_continue,
image_to_end,
video_to_continue,
@ -2044,6 +2093,8 @@ def generate_video_tab(image2video=False):
slg_layers,
slg_start_perc,
slg_end_perc,
cfg_star_switch,
cfg_zero_step,
state,
gr.State(image2video)
],
@ -2175,9 +2226,24 @@ def generate_configuration_tab():
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=metadata,
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 0),
label="Keep Previously Generated Videos when starting a Generation Batch"
)
msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes")
apply_btn.click(
@ -2195,6 +2261,7 @@ def generate_configuration_tab():
metadata_choice,
default_ui_choice,
boost_choice,
clear_file_list_choice,
],
outputs= msg
)
@ -2262,7 +2329,7 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):

View File

@ -16,6 +16,6 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.3.3
mmgp==3.3.4
peft==0.14.0
mutagen

View File

@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
from PIL import Image
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)
def optimized_scale(positive_flat, negative_flat):
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
c = b1.shape[-1]
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
#norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
#normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
#zero when norms are zero
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
#slerp
dot = (b1_normalized*b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
#technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def common_upscale(samples, width, height, upscale_method, crop):
orig_shape = tuple(samples.shape)
if len(orig_shape) > 4:
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
samples = samples.movedim(2, 1)
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
if crop == "center":
old_width = samples.shape[-1]
old_height = samples.shape[-2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
else:
s = samples
if upscale_method == "bislerp":
out = bislerp(s, width, height)
elif upscale_method == "lanczos":
out = lanczos(s, width, height)
else:
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if len(orig_shape) == 4:
return out
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star
class WanI2V:
@ -227,6 +167,8 @@ class WanI2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@ -375,7 +317,7 @@ class WanI2V:
# sample videos
latent = noise
batch_size = latent.shape[0]
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {
@ -456,8 +398,23 @@ class WanI2V:
del latent_model_input
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0.
else:
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond
latent = latent.to(

View File

@ -70,30 +70,31 @@ def sageattn_wrapper(
return o
# # try:
# try:
# if True:
# from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
# @torch.compiler.disable()
# def sageattn_window_wrapper(
# qkv_list,
# attention_length,
# window
# ):
# q,k, v = qkv_list
# padding_length = q.shape[0] -attention_length
# q = q[:attention_length, :, : ].unsqueeze(0)
# k = k[:attention_length, :, : ].unsqueeze(0)
# v = v[:attention_length, :, : ].unsqueeze(0)
# o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
# del q, k ,v
# qkv_list.clear()
# from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
# @torch.compiler.disable()
# def sageattn_window_wrapper(
# qkv_list,
# attention_length,
# window
# ):
# q,k, v = qkv_list
# padding_length = q.shape[0] -attention_length
# q = q[:attention_length, :, : ].unsqueeze(0)
# k = k[:attention_length, :, : ].unsqueeze(0)
# v = v[:attention_length, :, : ].unsqueeze(0)
# qkvl_list = [q, k , v]
# del q, k ,v
# o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
# qkv_list.clear()
# if padding_length > 0:
# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
# if padding_length > 0:
# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
# return o
# # except ImportError:
# # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
# return o
# except ImportError:
# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
@torch.compiler.disable()
def sdpa_wrapper(
@ -253,17 +254,19 @@ def pay_attention(
# nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
# window = 0
# start_window_step = int(max_steps * 0.4)
# start_window_step = int(max_steps * 0.3)
# start_layer = 10
# if (layer < start_layer ) or current_step <start_window_step:
# end_layer = 30
# if (layer < start_layer or layer > end_layer ) or current_step <start_window_step:
# window = 0
# else:
# coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
# # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
# coef = 0.3
# print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
# window = math.ceil(coef* nb_latents)
# invert_spaces = (layer + current_step) % 2 == 0 and window > 0
# invert_spaces = False
# def flip(q):
# q = q.reshape(*embed_sizes, *q.shape[-2:])
# q = q.transpose(0,2)

View File

@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
self.init_weights()
# self.freqs = torch.cat([
# rope_params(1024, d - 4 * (d // 6)), #44
# rope_params(1024, 2 * (d // 6)), #42
# rope_params(1024, 2 * (d // 6)) #42
# ],dim=1)
def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
dim = self.dim
num_heads = self.num_heads
d = dim // num_heads
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
c2, s2 = rope_params(1024, 2 * (d // 6)) #42
c3, s3 = rope_params(1024, 2 * (d // 6)) #42
return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []

View File

@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(
if pv_accum_dtype == "fp32":
if smooth_v:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
else:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
o = o[..., :head_dim_og]

View File

@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
def optimized_scale(positive_flat, negative_flat):
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star
class WanT2V:
def __init__(
@ -136,6 +150,8 @@ class WanT2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from text prompt using diffusion process.
@ -240,7 +256,7 @@ class WanT2V:
# sample videos
latents = noise
batch_size =len(latents)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@ -249,7 +265,6 @@ class WanT2V:
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
@ -280,8 +295,23 @@ class WanT2V:
return None
del latent_model_input
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0.
else:
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond
temp_x0 = sample_scheduler.step(