diff --git a/README.md b/README.md index e37c10d..e4480c2 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! +* Mar 19 2022: 👋 Wan2.1GP v3.2: + - Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team:**\ + Dont hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star + - Added back support for Pytorch compilation with Loras. It seems it had been broken for some time + - Added possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings) + You will need one more *pip install -r requirements.txt* * Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\ You will need one more *pip install -r requirements.txt* * Mar 18 2022: 👋 Wan2.1GP v3.0: diff --git a/gradio_server.py b/gradio_server.py index d999fd0..78076a2 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -23,7 +23,7 @@ import asyncio from wan.utils import prompt_parser PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.3.3" +target_mmgp_version = "3.3.4" from importlib.metadata import version mmgp_version = version("mmgp") if mmgp_version != target_mmgp_version: @@ -300,6 +300,7 @@ if not Path(server_config_filename).is_file(): "metadata_type": "metadata", "default_ui": "t2v", "boost" : 1, + "clear_file_list" : 0, "vae_config": 0, "profile" : profile_type.LowRAM_LowVRAM } @@ -382,7 +383,6 @@ if len(args.vae_config) > 0: reload_needed = False default_ui = server_config.get("default_ui", "t2v") -metadata = server_config.get("metadata_type", "metadata") save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) use_image2video = default_ui != "t2v" if args.t2v: @@ -741,7 +741,8 @@ def apply_changes( state, vae_config_choice, metadata_choice, default_ui_choice ="t2v", - boost_choice = 1 + boost_choice = 1, + clear_file_list = 0, ): if args.lock_config: return @@ -760,6 +761,7 @@ def apply_changes( state, "metadata_choice": metadata_choice, "default_ui" : default_ui_choice, "boost" : boost_choice, + "clear_file_list" : clear_file_list } if Path(server_config_filename).is_file(): @@ -792,7 +794,7 @@ def apply_changes( state, text_encoder_filename = server_config["text_encoder_filename"] vae_config = server_config["vae_config"] boost = server_config["boost"] - if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ): + if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ): pass else: reload_needed = True @@ -849,13 +851,17 @@ def refresh_gallery(state, txt): if len(prompt) == 0: return file_list, gr.Text(visible= False, value="") else: + choice = 0 + if "in_progress" in state: + choice = state.get("selected",0) + prompts_max = state.get("prompts_max",0) prompt_no = state.get("prompt_no",0) if prompts_max >1 : label = f"Current Prompt ({prompt_no+1}/{prompts_max})" else: label = f"Current Prompt" - return file_list, gr.Text(visible= True, value=prompt, label=label) + return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label) def finalize_gallery(state): @@ -863,6 +869,8 @@ def finalize_gallery(state): if "in_progress" in 
state: del state["in_progress"] choice = state.get("selected",0) + # file_list = state.get("file_list", []) + state["extra_orders"] = 0 time.sleep(0.2) @@ -930,6 +938,7 @@ def generate_video( tea_cache_start_step_perc, loras_choices, loras_mult_choices, + image_prompt_type, image_to_continue, image_to_end, video_to_continue, @@ -938,7 +947,9 @@ def generate_video( slg_switch, slg_layers, slg_start, - slg_end, + slg_end, + cfg_star_switch, + cfg_zero_step, state, image2video, progress=gr.Progress() #track_tqdm= True @@ -1031,6 +1042,8 @@ def generate_video( if len(prompts) ==0: return if image2video: + if image_prompt_type == 0: + image_to_end = None if image_to_continue is not None: if isinstance(image_to_continue, list): image_to_continue = [ tup[0] for tup in image_to_continue ] @@ -1135,7 +1148,6 @@ def generate_video( if "abort" in state: del state["abort"] state["in_progress"] = True - state["selected"] = 0 enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1 # VAE Tiling @@ -1172,8 +1184,21 @@ def generate_video( if seed == None or seed <0: seed = random.randint(0, 999999999) - file_list = [] + clear_file_list = server_config.get("clear_file_list", 0) + file_list = state.get("file_list", []) + if clear_file_list > 0: + file_list_current_size = len(file_list) + keep_file_from = max(file_list_current_size - clear_file_list, 0) + files_removed = keep_file_from + choice = state.get("selected",0) + choice = max(choice- files_removed, 0) + file_list = file_list[ keep_file_from: ] + else: + file_list = [] + choice = 0 + state["selected"] = choice state["file_list"] = file_list + global save_path os.makedirs(save_path, exist_ok=True) video_no = 0 @@ -1240,6 +1265,8 @@ def generate_video( slg_layers = slg_layers, slg_start = slg_start/100, slg_end = slg_end/100, + cfg_star_switch = cfg_star_switch, + cfg_zero_step = cfg_zero_step, ) else: @@ -1260,6 +1287,8 @@ def generate_video( slg_layers = slg_layers, slg_start = slg_start/100, slg_end = slg_end/100, + cfg_star_switch = cfg_star_switch, + cfg_zero_step = cfg_zero_step, ) except Exception as e: gen_in_progress = False @@ -1326,7 +1355,7 @@ def generate_video( value_range=(-1, 1)) configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end) + loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step) metadata_choice = server_config.get("metadata_choice","metadata") if metadata_choice == "json": @@ -1642,7 +1671,7 @@ def switch_advanced(state, new_advanced, lset_name): def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc): + loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): loras = state["loras"] activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] @@ -1666,7 +1695,9 @@ def 
get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol "slg_switch": slg_switch, "slg_layers": slg_layers, "slg_start_perc": slg_start_perc, - "slg_end_perc": slg_end_perc + "slg_end_perc": slg_end_perc, + "cfg_star_switch": cfg_star_switch, + "cfg_zero_step": cfg_zero_step } if i2v: @@ -1678,13 +1709,13 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol return ui_settings def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc): + loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step): if state.get("validate_success",0) != 1: return ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, - loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc) + loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step) defaults_filename = get_settings_file_name(use_image2video) @@ -1955,7 +1986,7 @@ def generate_video_tab(image2video=False): label="RIFLEx positional embedding to generate long video" ) with gr.Row(): - gr.Markdown("Experimental: Skip Layer guidance,should improve video quality") + gr.Markdown("Experimental: Skip Layer Guidance, should improve video quality") with gr.Row(): slg_switch = gr.Dropdown( choices=[ @@ -1979,6 +2010,23 @@ def generate_video_tab(image2video=False): with gr.Row(): slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start") slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end") + + with gr.Row(): + gr.Markdown("Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt") + with gr.Row(): + cfg_star_switch = gr.Dropdown( + choices=[ + ("OFF", 0), + ("ON", 1), + ], + value=ui_defaults.get("cfg_star_switch",0), + visible=True, + scale = 1, + label="CFG Star" + ) + with gr.Row(): + cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)") + with gr.Row(): save_settings_btn = gr.Button("Set Settings as Default") show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( @@ -1997,7 +2045,7 @@ def generate_video_tab(image2video=False): save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, - 
slg_start_perc, slg_end_perc ], outputs = []) + slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = []) save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) @@ -2035,6 +2083,7 @@ def generate_video_tab(image2video=False): tea_cache_start_step_perc, loras_choices, loras_mult_choices, + image_prompt_type_radio, image_to_continue, image_to_end, video_to_continue, @@ -2044,6 +2093,8 @@ def generate_video_tab(image2video=False): slg_layers, slg_start_perc, slg_end_perc, + cfg_star_switch, + cfg_zero_step, state, gr.State(image2video) ], @@ -2175,9 +2226,24 @@ def generate_configuration_tab(): ("Add metadata to video", "metadata"), ("Neither", "none") ], - value=metadata, + value=server_config.get("metadata_type", "metadata"), label="Metadata Handling" ) + + clear_file_list_choice = gr.Dropdown( + choices=[ + ("None", 0), + ("Keep the last video", 1), + ("Keep the last 5 videos", 5), + ("Keep the last 10 videos", 10), + ("Keep the last 20 videos", 20), + ("Keep the last 30 videos", 30), + ], + value=server_config.get("clear_file_list", 0), + label="Keep Previously Generated Videos when starting a Generation Batch" + ) + + msg = gr.Markdown() apply_btn = gr.Button("Apply Changes") apply_btn.click( @@ -2195,6 +2261,7 @@ def generate_configuration_tab(): metadata_choice, default_ui_choice, boost_choice, + clear_file_list_choice, ], outputs= msg ) @@ -2262,7 +2329,7 @@ def create_demo(): } """ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo: - gr.Markdown("
Wan 2.1GP v3.1 by DeepBeepMeep (Updates)
") + gr.Markdown("
Wan 2.1GP v3.2 by DeepBeepMeep (Updates)
") gr.Markdown("Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !") with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False): diff --git a/requirements.txt b/requirements.txt index bd928de..7576271 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,6 @@ gradio>=5.0.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 -mmgp==3.3.3 +mmgp==3.3.4 peft==0.14.0 mutagen \ No newline at end of file diff --git a/wan/image2video.py b/wan/image2video.py index a853665..e236bd7 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed from PIL import Image -def lanczos(samples, width, height): - images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] - images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] - images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] - result = torch.stack(images) - return result.to(samples.device, samples.dtype) +def optimized_scale(positive_flat, negative_flat): -def bislerp(samples, width, height): - def slerp(b1, b2, r): - '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - c = b1.shape[-1] + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 - #norms - b1_norms = torch.norm(b1, dim=-1, keepdim=True) - b2_norms = torch.norm(b2, dim=-1, keepdim=True) - - #normalize - b1_normalized = b1 / b1_norms - b2_normalized = b2 / b2_norms - - #zero when norms are zero - b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 - b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - - #slerp - dot = (b1_normalized*b2_normalized).sum(1) - omega = torch.acos(dot) - so = torch.sin(omega) - - #technically not mathematically correct, but more pleasing? 
- res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized - res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - - #edge cases for same or polar opposites - res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] - res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] - return res - - -def common_upscale(samples, width, height, upscale_method, crop): - orig_shape = tuple(samples.shape) - if len(orig_shape) > 4: - samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1]) - samples = samples.movedim(2, 1) - samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1]) - if crop == "center": - old_width = samples.shape[-1] - old_height = samples.shape[-2] - old_aspect = old_width / old_height - new_aspect = width / height - x = 0 - y = 0 - if old_aspect > new_aspect: - x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) - elif old_aspect < new_aspect: - y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2) - else: - s = samples - - if upscale_method == "bislerp": - out = bislerp(s, width, height) - elif upscale_method == "lanczos": - out = lanczos(s, width, height) - else: - out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) - - if len(orig_shape) == 4: - return out - - out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width)) - return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width)) + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + class WanI2V: @@ -227,6 +167,8 @@ class WanI2V: slg_layers = None, slg_start = 0.0, slg_end = 1.0, + cfg_star_switch = True, + cfg_zero_step = 5, ): r""" Generates video frames from input image and text prompt using diffusion process. @@ -375,7 +317,7 @@ class WanI2V: # sample videos latent = noise - + batch_size = latent.shape[0] freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx) arg_c = { @@ -456,8 +398,23 @@ class WanI2V: del latent_model_input if offload_model: torch.cuda.empty_cache() - noise_pred = noise_pred_uncond + guide_scale * ( - noise_pred_cond - noise_pred_uncond) + # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + noise_pred_text = noise_pred_cond + if cfg_star_switch: + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + + alpha = optimized_scale(positive_flat,negative_flat) + alpha = alpha.view(batch_size, 1, 1, 1) + + + if (i <= cfg_zero_step): + noise_pred = noise_pred_text*0. 
+ else: + noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha) + else: + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) + del noise_pred_uncond latent = latent.to( diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 6861283..e19e387 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -70,30 +70,31 @@ def sageattn_wrapper( return o -# # try: +# try: # if True: -# from sageattention import sageattn_qk_int8_pv_fp8_window_cuda -# @torch.compiler.disable() -# def sageattn_window_wrapper( -# qkv_list, -# attention_length, -# window -# ): -# q,k, v = qkv_list -# padding_length = q.shape[0] -attention_length -# q = q[:attention_length, :, : ].unsqueeze(0) -# k = k[:attention_length, :, : ].unsqueeze(0) -# v = v[:attention_length, :, : ].unsqueeze(0) -# o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0) -# del q, k ,v -# qkv_list.clear() + # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda + # @torch.compiler.disable() + # def sageattn_window_wrapper( + # qkv_list, + # attention_length, + # window + # ): + # q,k, v = qkv_list + # padding_length = q.shape[0] -attention_length + # q = q[:attention_length, :, : ].unsqueeze(0) + # k = k[:attention_length, :, : ].unsqueeze(0) + # v = v[:attention_length, :, : ].unsqueeze(0) + # qkvl_list = [q, k , v] + # del q, k ,v + # o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0) + # qkv_list.clear() -# if padding_length > 0: -# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0) + # if padding_length > 0: + # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0) -# return o -# # except ImportError: -# # sageattn = sageattn_qk_int8_pv_fp8_window_cuda + # return o +# except ImportError: +# sageattn = sageattn_qk_int8_pv_fp8_window_cuda @torch.compiler.disable() def sdpa_wrapper( @@ -253,17 +254,19 @@ def pay_attention( # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2] # window = 0 - # start_window_step = int(max_steps * 0.4) + # start_window_step = int(max_steps * 0.3) # start_layer = 10 - # if (layer < start_layer ) or current_step end_layer ) or current_step 0 - + # invert_spaces = False # def flip(q): # q = q.reshape(*embed_sizes, *q.shape[-2:]) # q = q.transpose(0,2) diff --git a/wan/modules/model.py b/wan/modules/model.py index 8c7dfdf..3e2ea5c 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin): self.init_weights() - # self.freqs = torch.cat([ - # rope_params(1024, d - 4 * (d // 6)), #44 - # rope_params(1024, 2 * (d // 6)), #42 - # rope_params(1024, 2 * (d // 6)) #42 - # ],dim=1) - - - def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"): - dim = self.dim - num_heads = self.num_heads - d = dim // num_heads - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - - - c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44 - c2, s2 = rope_params(1024, 2 * (d // 6)) #42 - c3, s3 = rope_params(1024, 2 * (d // 6)) #42 - - return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device)) - def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): rescale_func = 
np.poly1d(self.coefficients) e_list = [] diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py index de94a60..f6d23ee 100644 --- a/wan/modules/sage2_core.py +++ b/wan/modules/sage2_core.py @@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda( if pv_accum_dtype == "fp32": if smooth_v: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window) + lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window) else: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window) + lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window) elif pv_accum_dtype == "fp32+fp32": - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window) + lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window) o = o[..., :head_dim_og] diff --git a/wan/text2video.py b/wan/text2video.py index 88046db..379954f 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed +def optimized_scale(positive_flat, negative_flat): + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + + class WanT2V: def __init__( @@ -136,6 +150,8 @@ class WanT2V: slg_layers = None, slg_start = 0.0, slg_end = 1.0, + cfg_star_switch = True, + cfg_zero_step = 5, ): r""" Generates video frames from text prompt using diffusion process. 
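The hunks above (wan/image2video.py) and below (wan/text2video.py) wire CFG Zero* (https://github.com/WeichenFan/CFG-Zero-star) into both denoising loops: a per-sample projection coefficient s* = <v_cond, v_uncond> / ||v_uncond||², a zero-init that discards the prediction for the first `cfg_zero_step + 1` steps, and a rescaled unconditional branch afterwards. As a reference, here is a minimal, self-contained sketch of that guidance rule; the helper name `cfg_zero_star` and the shape handling are illustrative, not part of the pipelines' actual API.

```python
import torch

def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
    # s* = <v_cond, v_uncond> / (||v_uncond||^2 + eps), computed per sample
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm

def cfg_zero_star(noise_pred_cond, noise_pred_uncond, guide_scale, step, cfg_zero_step=-1):
    # Zero-init: return an all-zero prediction for the first cfg_zero_step + 1 steps
    if step <= cfg_zero_step:
        return noise_pred_cond * 0.0
    batch_size = noise_pred_cond.shape[0]
    alpha = optimized_scale(
        noise_pred_cond.view(batch_size, -1),
        noise_pred_uncond.view(batch_size, -1),
    ).view(batch_size, *([1] * (noise_pred_cond.dim() - 1)))
    # Rescale the unconditional prediction by its projection onto the conditional
    # one, then apply the usual classifier-free guidance extrapolation
    return noise_pred_uncond * alpha + guide_scale * (noise_pred_cond - noise_pred_uncond * alpha)
```

With `cfg_star_switch` off, this reduces to the standard `uncond + guide_scale * (cond - uncond)` update, which is what the unchanged `else` branch in both files keeps.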
@@ -240,7 +256,7 @@ class WanT2V: # sample videos latents = noise - + batch_size =len(latents) freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} @@ -249,7 +265,6 @@ class WanT2V: # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps} - if self.model.enable_teacache: self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) if callback != None: @@ -280,8 +295,23 @@ class WanT2V: return None del latent_model_input - noise_pred = noise_pred_uncond + guide_scale * ( - noise_pred_cond - noise_pred_uncond) + + # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + noise_pred_text = noise_pred_cond + if cfg_star_switch: + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + + alpha = optimized_scale(positive_flat,negative_flat) + alpha = alpha.view(batch_size, 1, 1, 1) + + + if (i <= cfg_zero_step): + noise_pred = noise_pred_text*0. + else: + noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha) + else: + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) del noise_pred_uncond temp_x0 = sample_scheduler.step(
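The other user-visible change in this patch is the `clear_file_list` setting: instead of wiping the gallery at the start of every generation batch, `generate_video` can now keep the most recent N outputs and shift the remembered selection index by the number of files dropped. Below is a condensed restatement of that bookkeeping; the helper name `trim_file_list` is illustrative, and a plain dict stands in for the Gradio session `state`.

```python
def trim_file_list(state: dict, clear_file_list: int) -> list:
    # Keep at most `clear_file_list` previously generated videos (0 = keep none),
    # mirroring the logic added to generate_video() in gradio_server.py.
    file_list = state.get("file_list", [])
    if clear_file_list > 0:
        keep_from = max(len(file_list) - clear_file_list, 0)
        # `keep_from` files are dropped from the front of the gallery, so the
        # remembered selection index shifts back by the same amount.
        choice = max(state.get("selected", 0) - keep_from, 0)
        file_list = file_list[keep_from:]
    else:
        file_list = []
        choice = 0
    state["selected"] = choice
    state["file_list"] = file_list
    return file_list
```

This pairs with the `refresh_gallery` change, which now re-emits `gr.Gallery(selected_index=choice, value=file_list)` so the current selection survives a gallery refresh.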