diff --git a/README.md b/README.md
index e37c10d..e4480c2 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
+* Mar 19 2025: 👋 Wan2.1GP v3.2:
+ - Added Classifier-Free Guidance Zero Star. The video should match the text prompt better (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team**.\
+ Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
+ - Added back support for PyTorch compilation with LoRAs. It seems it had been broken for some time.
+ - Added the option to keep a number of previously generated videos in the Video Gallery (useful to compare the outputs of different settings).\
+ You will need one more *pip install -r requirements.txt*
* Mar 19 2025: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
You will need one more *pip install -r requirements.txt*
* Mar 18 2025: 👋 Wan2.1GP v3.0:
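
For readers who want to know what the headline feature actually does: the CFG Zero Star update rescales the unconditional prediction before the usual classifier-free-guidance mix and can optionally zero out the earliest denoising steps. The sketch below is ours (the helper name `cfg_zero_star_guidance` and its signature are illustrative, not the repository's API); it mirrors the logic added to `wan/text2video.py` and `wan/image2video.py` further down in this patch.

```python
import torch

# Illustrative sketch (ours) of the CFG Zero Star guidance added by this patch;
# see https://github.com/WeichenFan/CFG-Zero-star for the original method.
def cfg_zero_star_guidance(cond, uncond, guide_scale, step, cfg_zero_step=-1):
    b = cond.shape[0]
    pos = cond.view(b, -1)
    neg = uncond.view(b, -1)
    # Projection scale alpha = <cond, uncond> / ||uncond||^2, per batch element.
    alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    alpha = alpha.view(b, *([1] * (cond.dim() - 1)))
    if step <= cfg_zero_step:
        # Optionally zero out the prediction for the first few denoising steps.
        return torch.zeros_like(cond)
    # Standard CFG, with the unconditional branch rescaled by alpha.
    return uncond * alpha + guide_scale * (cond - uncond * alpha)
```

With `cfg_zero_step=-1` (the UI default added below) no step is zeroed, and when `alpha` is close to 1 this reduces to plain classifier-free guidance.
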
diff --git a/gradio_server.py b/gradio_server.py
index 088cf12..b429867 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -24,7 +24,7 @@ import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.3.3"
+target_mmgp_version = "3.3.4"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@@ -98,6 +98,7 @@ def process_prompt_and_add_tasks(
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
+ image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
@@ -107,6 +108,8 @@ def process_prompt_and_add_tasks(
slg_layers,
slg_start,
slg_end,
+ cfg_star_switch,
+ cfg_zero_step,
state_arg,
image2video
):
@@ -138,6 +141,7 @@ def process_prompt_and_add_tasks(
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
+ image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
@@ -147,6 +151,8 @@ def process_prompt_and_add_tasks(
slg_layers,
slg_start,
slg_end,
+ cfg_star_switch,
+ cfg_zero_step,
state_arg,
image2video
)
@@ -380,7 +386,6 @@ def _parse_args():
default="",
help="Server name"
)
-
parser.add_argument(
"--gpu",
type=str,
@@ -482,7 +487,6 @@ def get_lora_dir(i2v):
attention_modes_installed = get_attention_modes()
attention_modes_supported = get_supported_attention_modes()
-
args = _parse_args()
args.flow_reverse = True
@@ -513,6 +517,7 @@ if not Path(server_config_filename).is_file():
"metadata_type": "metadata",
"default_ui": "t2v",
"boost" : 1,
+ "clear_file_list" : 0,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM,
"reload_model": 2 }
@@ -596,7 +601,6 @@ if len(args.vae_config) > 0:
reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
-metadata = server_config.get("metadata_type", "metadata")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
use_image2video = default_ui != "t2v"
if args.t2v:
@@ -956,6 +960,7 @@ def apply_changes( state,
metadata_choice,
default_ui_choice ="t2v",
boost_choice = 1,
+ clear_file_list = 0,
reload_choice = 1
):
if args.lock_config:
@@ -975,6 +980,7 @@ def apply_changes( state,
"metadata_choice": metadata_choice,
"default_ui" : default_ui_choice,
"boost" : boost_choice,
+ "clear_file_list" : clear_file_list
"reload_model" : reload_choice,
}
@@ -1008,7 +1014,7 @@ def apply_changes( state,
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
boost = server_config["boost"]
- if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
+ if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
pass
else:
reload_needed = True
@@ -1059,6 +1065,10 @@ def finalize_gallery(state):
if "in_progress" in state:
del state["in_progress"]
choice = state.get("selected",0)
+ # file_list = state.get("file_list", [])
+
+
+ state["extra_orders"] = 0
time.sleep(0.2)
global gen_in_progress
gen_in_progress = False
@@ -1096,6 +1106,7 @@ def generate_video(
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
+ image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
@@ -1104,7 +1115,9 @@ def generate_video(
slg_switch,
slg_layers,
slg_start,
- slg_end,
+ slg_end,
+ cfg_star_switch,
+ cfg_zero_step,
state,
image2video,
progress=gr.Progress() #track_tqdm= True
@@ -1243,7 +1256,6 @@ def generate_video(
if "abort" in state:
del state["abort"]
state["in_progress"] = True
- state["selected"] = 0
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
@@ -1281,17 +1293,25 @@ def generate_video(
seed = random.randint(0, 999999999)
global file_list
- state["file_list"] = file_list
+ clear_file_list = server_config.get("clear_file_list", 0)
+ file_list = state.get("file_list", [])
+ if clear_file_list > 0:
+ file_list_current_size = len(file_list)
+ keep_file_from = max(file_list_current_size - clear_file_list, 0)
+ files_removed = keep_file_from
+ choice = state.get("selected",0)
+ choice = max(choice - files_removed, 0)
+ file_list = file_list[keep_file_from:]
+ else:
+ file_list = []
+ choice = 0
+ state["selected"] = choice
+ state["file_list"] = file_list
+
global save_path
os.makedirs(save_path, exist_ok=True)
- abort = False
- if trans.enable_teacache:
- trans.teacache_counter = 0
- trans.num_steps = num_inference_steps
- trans.teacache_skipped_steps = 0
- trans.previous_residual_uncond = None
- trans.previous_residual_cond = None
video_no = 0
+ abort = False
repeats = f"{video_no}/{repeat_generation}"
callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
offload.shared_state["callback"] = callback
@@ -1301,14 +1321,22 @@ def generate_video(
for i in range(repeat_generation):
try:
with tracker_lock:
+ start_time = time.time()
progress_tracker[task_id] = {
'current_step': 0,
'total_steps': num_inference_steps,
- 'start_time': time.time(),
- 'last_update': time.time(),
+ 'start_time': start_time,
+ 'last_update': start_time,
'repeats': f"{video_no}/{repeat_generation}",
'status': "Encoding Prompt"
}
+ if trans.enable_teacache:
+ trans.teacache_counter = 0
+ trans.num_steps = num_inference_steps
+ trans.teacache_skipped_steps = 0
+ trans.previous_residual_uncond = None
+ trans.previous_residual_cond = None
+
video_no += 1
if image2video:
samples = wan_model.generate(
@@ -1330,6 +1358,8 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
+ cfg_star_switch = cfg_star_switch,
+ cfg_zero_step = cfg_zero_step,
)
else:
samples = wan_model.generate(
@@ -1349,13 +1379,15 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
+ cfg_star_switch = cfg_star_switch,
+ cfg_zero_step = cfg_zero_step,
)
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
- if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
- if(trans): offload.unload_loras_from_model(trans)
+ offload.last_offload_obj.unload_all()
+ offload.unload_loras_from_model(trans)
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
@@ -1399,6 +1431,7 @@ def generate_video(
end_time = time.time()
abort = True
state["prompt"] = ""
+ yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
sample = samples.cpu()
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
@@ -1416,9 +1449,9 @@ def generate_video(
nrow=1,
normalize=True,
value_range=(-1, 1))
-
+
configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
+ loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
@@ -1432,7 +1465,15 @@ def generate_video(
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
+ if video_no < total_video:
+ yield status
+ else:
+ end_time = time.time()
+ state["prompt"] = ""
+ yield f"Total Generation Time: {end_time-start_time:.1f}s"
seed += 1
+ repeat_no += 1
+
last_model_type = image2video
if temp_filename!= None and os.path.isfile(temp_filename):
@@ -1725,8 +1766,9 @@ def switch_advanced(state, new_advanced, lset_name):
else:
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
+
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
loras = state["loras"]
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@@ -1750,7 +1792,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
"slg_switch": slg_switch,
"slg_layers": slg_layers,
"slg_start_perc": slg_start_perc,
- "slg_end_perc": slg_end_perc
+ "slg_end_perc": slg_end_perc,
+ "cfg_star_switch": cfg_star_switch,
+ "cfg_zero_step": cfg_zero_step
}
if i2v:
@@ -1758,21 +1802,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
ui_settings["image_prompt_type"] = image_prompt_type
else:
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
+
return ui_settings
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
if state.get("validate_success",0) != 1:
return
ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
defaults_filename = get_settings_file_name(use_image2video)
with open(defaults_filename, "w", encoding="utf-8") as f:
- json.dump(ui_defaults , f, indent=4)
+ json.dump(ui_defaults, f, indent=4)
gr.Info("New Default Settings saved")
@@ -1907,6 +1952,7 @@ def generate_video_tab(image2video=False):
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
else:
return gr.Image(visible = (image_prompt_type_radio == 1) )
+
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
@@ -2037,7 +2083,7 @@ def generate_video_tab(image2video=False):
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
- gr.Markdown("Experimental: Skip Layer guidance,should improve video quality")
+ gr.Markdown("Experimental: Skip Layer Guidance, should improve video quality")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@@ -2061,6 +2107,23 @@ def generate_video_tab(image2video=False):
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
+
+ with gr.Row():
+ gr.Markdown("Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
+ with gr.Row():
+ cfg_star_switch = gr.Dropdown(
+ choices=[
+ ("OFF", 0),
+ ("ON", 1),
+ ],
+ value=ui_defaults.get("cfg_star_switch",0),
+ visible=True,
+ scale = 1,
+ label="CFG Star"
+ )
+ with gr.Row():
+ cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Step (extra processing)")
+
with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default")
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
@@ -2103,7 +2166,7 @@ def generate_video_tab(image2video=False):
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
- slg_start_perc, slg_end_perc ], outputs = [])
+ slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2133,6 +2196,7 @@ def generate_video_tab(image2video=False):
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
+ image_prompt_type_radio,
image_to_continue,
image_to_end,
video_to_continue,
@@ -2142,6 +2206,8 @@ def generate_video_tab(image2video=False):
slg_layers,
slg_start_perc,
slg_end_perc,
+ cfg_star_switch,
+ cfg_zero_step,
state,
gr.State(image2video)
]
@@ -2276,7 +2342,7 @@ def generate_configuration_tab():
("Add metadata to video", "metadata"),
("Neither", "none")
],
- value=metadata,
+ value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)
reload_choice = gr.Dropdown(
@@ -2287,6 +2353,21 @@ def generate_configuration_tab():
value=server_config.get("reload_model",2),
label="Reload model"
)
+
+ clear_file_list_choice = gr.Dropdown(
+ choices=[
+ ("None", 0),
+ ("Keep the last video", 1),
+ ("Keep the last 5 videos", 5),
+ ("Keep the last 10 videos", 10),
+ ("Keep the last 20 videos", 20),
+ ("Keep the last 30 videos", 30),
+ ],
+ value=server_config.get("clear_file_list", 0),
+ label="Keep Previously Generated Videos when starting a Generation Batch"
+ )
+
+
msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes")
apply_btn.click(
@@ -2304,6 +2385,7 @@ def generate_configuration_tab():
metadata_choice,
default_ui_choice,
boost_choice,
+ clear_file_list_choice,
reload_choice,
],
outputs= msg
@@ -2322,6 +2404,9 @@ def generate_about_tab():
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
global lora_model_filename, use_image2video
+ t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
+ i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
+
new_t2v = evt.index == 0
new_i2v = evt.index == 1
use_image2video = new_i2v
@@ -2341,9 +2426,6 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
wan_model, offloadobj, trans = load_models(use_image2video)
del trans
- t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
- i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
-
if new_t2v:
lora_model_filename = t2v_state["loras_model"]
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
@@ -2470,7 +2552,7 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
- gr.Markdown("Wan 2.1GP v3.1 by DeepBeepMeep (Updates)
")
+ gr.Markdown("Wan 2.1GP v3.2 by DeepBeepMeep (Updates)
")
gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !")
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
diff --git a/requirements.txt b/requirements.txt
index bd928de..7576271 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,6 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
-mmgp==3.3.3
+mmgp==3.3.4
peft==0.14.0
mutagen
\ No newline at end of file
diff --git a/wan/image2video.py b/wan/image2video.py
index a853665..e236bd7 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
from PIL import Image
-def lanczos(samples, width, height):
- images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
- images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
- images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
- result = torch.stack(images)
- return result.to(samples.device, samples.dtype)
+def optimized_scale(positive_flat, negative_flat):
-def bislerp(samples, width, height):
- def slerp(b1, b2, r):
- '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
+ # Calculate the dot product
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
- c = b1.shape[-1]
+ # Squared norm of the unconditional prediction
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
- #norms
- b1_norms = torch.norm(b1, dim=-1, keepdim=True)
- b2_norms = torch.norm(b2, dim=-1, keepdim=True)
-
- #normalize
- b1_normalized = b1 / b1_norms
- b2_normalized = b2 / b2_norms
-
- #zero when norms are zero
- b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
- b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
-
- #slerp
- dot = (b1_normalized*b2_normalized).sum(1)
- omega = torch.acos(dot)
- so = torch.sin(omega)
-
- #technically not mathematically correct, but more pleasing?
- res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
- res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
-
- #edge cases for same or polar opposites
- res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
- res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
- return res
-
-
-def common_upscale(samples, width, height, upscale_method, crop):
- orig_shape = tuple(samples.shape)
- if len(orig_shape) > 4:
- samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
- samples = samples.movedim(2, 1)
- samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
- if crop == "center":
- old_width = samples.shape[-1]
- old_height = samples.shape[-2]
- old_aspect = old_width / old_height
- new_aspect = width / height
- x = 0
- y = 0
- if old_aspect > new_aspect:
- x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
- elif old_aspect < new_aspect:
- y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
- s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
- else:
- s = samples
-
- if upscale_method == "bislerp":
- out = bislerp(s, width, height)
- elif upscale_method == "lanczos":
- out = lanczos(s, width, height)
- else:
- out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
-
- if len(orig_shape) == 4:
- return out
-
- out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
- return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ st_star = dot_product / squared_norm
+
+ return st_star
+
class WanI2V:
@@ -227,6 +167,8 @@ class WanI2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
+ cfg_star_switch = True,
+ cfg_zero_step = 5,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@@ -375,7 +317,7 @@ class WanI2V:
# sample videos
latent = noise
-
+ batch_size = latent.shape[0]
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {
@@ -456,8 +398,23 @@ class WanI2V:
del latent_model_input
if offload_model:
torch.cuda.empty_cache()
- noise_pred = noise_pred_uncond + guide_scale * (
- noise_pred_cond - noise_pred_uncond)
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
+ noise_pred_text = noise_pred_cond
+ if cfg_star_switch:
+ positive_flat = noise_pred_text.view(batch_size, -1)
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
+
+ alpha = optimized_scale(positive_flat, negative_flat)
+ alpha = alpha.view(batch_size, 1, 1, 1)
+
+
+ if (i <= cfg_zero_step):
+ noise_pred = noise_pred_text * 0.
+ else:
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
+ else:
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
+
del noise_pred_uncond
latent = latent.to(
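
A quick way to sanity-check the `optimized_scale` helper introduced above: when the conditional and unconditional predictions coincide the projection scale is 1 (so CFG Zero Star reduces to plain CFG), and scaling the unconditional branch scales it inversely. This is a standalone check written by us, not part of the patch.

```python
import torch

def optimized_scale(positive_flat, negative_flat):
    # Same formula as above: alpha = <cond, uncond> / ||uncond||^2.
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm

x = torch.randn(2, 8)
assert torch.allclose(optimized_scale(x, x), torch.ones(2, 1), atol=1e-4)
assert torch.allclose(optimized_scale(x, 2 * x), torch.full((2, 1), 0.5), atol=1e-4)
```
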
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 6861283..e19e387 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -70,30 +70,31 @@ def sageattn_wrapper(
return o
-# # try:
+# try:
# if True:
-# from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
-# @torch.compiler.disable()
-# def sageattn_window_wrapper(
-# qkv_list,
-# attention_length,
-# window
-# ):
-# q,k, v = qkv_list
-# padding_length = q.shape[0] -attention_length
-# q = q[:attention_length, :, : ].unsqueeze(0)
-# k = k[:attention_length, :, : ].unsqueeze(0)
-# v = v[:attention_length, :, : ].unsqueeze(0)
-# o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
-# del q, k ,v
-# qkv_list.clear()
+ # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
+ # @torch.compiler.disable()
+ # def sageattn_window_wrapper(
+ # qkv_list,
+ # attention_length,
+ # window
+ # ):
+ # q,k, v = qkv_list
+ # padding_length = q.shape[0] -attention_length
+ # q = q[:attention_length, :, : ].unsqueeze(0)
+ # k = k[:attention_length, :, : ].unsqueeze(0)
+ # v = v[:attention_length, :, : ].unsqueeze(0)
+ # qkvl_list = [q, k , v]
+ # del q, k ,v
+ # o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
+ # qkv_list.clear()
-# if padding_length > 0:
-# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
+ # if padding_length > 0:
+ # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
-# return o
-# # except ImportError:
-# # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
+ # return o
+# except ImportError:
+# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
@torch.compiler.disable()
def sdpa_wrapper(
@@ -253,17 +254,19 @@ def pay_attention(
# nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
# window = 0
- # start_window_step = int(max_steps * 0.4)
+ # start_window_step = int(max_steps * 0.3)
# start_layer = 10
- # if (layer < start_layer ) or current_step end_layer ) or current_step 0
-
+ # invert_spaces = False
# def flip(q):
# q = q.reshape(*embed_sizes, *q.shape[-2:])
# q = q.transpose(0,2)
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 8c7dfdf..3e2ea5c 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
self.init_weights()
- # self.freqs = torch.cat([
- # rope_params(1024, d - 4 * (d // 6)), #44
- # rope_params(1024, 2 * (d // 6)), #42
- # rope_params(1024, 2 * (d // 6)) #42
- # ],dim=1)
-
-
- def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
- dim = self.dim
- num_heads = self.num_heads
- d = dim // num_heads
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
-
-
- c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
- c2, s2 = rope_params(1024, 2 * (d // 6)) #42
- c3, s3 = rope_params(1024, 2 * (d // 6)) #42
-
- return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
-
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py
index de94a60..f6d23ee 100644
--- a/wan/modules/sage2_core.py
+++ b/wan/modules/sage2_core.py
@@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(
if pv_accum_dtype == "fp32":
if smooth_v:
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
else:
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
elif pv_accum_dtype == "fp32+fp32":
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
o = o[..., :head_dim_og]
diff --git a/wan/text2video.py b/wan/text2video.py
index 88046db..379954f 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
+def optimized_scale(positive_flat, negative_flat):
+
+ # Calculate the dot product
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+ # Squared norm of the unconditional prediction
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ st_star = dot_product / squared_norm
+
+ return st_star
+
+
class WanT2V:
def __init__(
@@ -136,6 +150,8 @@ class WanT2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
+ cfg_star_switch = True,
+ cfg_zero_step = 5,
):
r"""
Generates video frames from text prompt using diffusion process.
@@ -240,7 +256,7 @@ class WanT2V:
# sample videos
latents = noise
-
+ batch_size = len(latents)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@@ -249,7 +265,6 @@ class WanT2V:
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
-
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
@@ -280,8 +295,23 @@ class WanT2V:
return None
del latent_model_input
- noise_pred = noise_pred_uncond + guide_scale * (
- noise_pred_cond - noise_pred_uncond)
+
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
+ noise_pred_text = noise_pred_cond
+ if cfg_star_switch:
+ positive_flat = noise_pred_text.view(batch_size, -1)
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
+
+ alpha = optimized_scale(positive_flat, negative_flat)
+ alpha = alpha.view(batch_size, 1, 1, 1)
+
+
+ if (i <= cfg_zero_step):
+ noise_pred = noise_pred_text * 0.
+ else:
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
+ else:
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond
temp_x0 = sample_scheduler.step(
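
In equation form, the guidance combination added in both `text2video.py` and `image2video.py` is, in our notation:

```latex
% Our notation for the CFG Zero Star combination used above.
\alpha = \frac{\langle v_{\text{cond}},\, v_{\text{uncond}}\rangle}{\lVert v_{\text{uncond}}\rVert^{2} + \varepsilon},
\qquad
v =
\begin{cases}
0, & i \le \texttt{cfg\_zero\_step},\\[4pt]
\alpha\, v_{\text{uncond}} + g\,\bigl(v_{\text{cond}} - \alpha\, v_{\text{uncond}}\bigr), & \text{otherwise},
\end{cases}
```

which equals $g\,v_{\text{cond}} + (1-g)\,\alpha\, v_{\text{uncond}}$ and falls back to standard classifier-free guidance when $\alpha = 1$ (or when `cfg_star_switch` is off).
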