Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit 826d5ac84f: Added CFG Zero *
Parent: 91e35a806c
@@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

## 🔥 Latest News!!

* Mar 19 2025: 👋 Wan2.1GP v3.2:
    - Added Classifier-Free Guidance Zero Star. The video should better match the text prompt (especially with text2video) at no performance cost. Many thanks to the **CFG Zero * Team** (a short illustrative sketch of the guidance rule follows this news section):\
    Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
    - Added back support for PyTorch compilation with Loras. It seems it had been broken for some time.
    - Added the possibility to keep a number of previously generated videos in the Video Gallery (useful to compare outputs of different settings).
    You will need to run *pip install -r requirements.txt* one more time.
* Mar 19 2025: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run).\
    You will need to run *pip install -r requirements.txt* one more time.
* Mar 18 2025: 👋 Wan2.1GP v3.0:
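For readers who want to see what the new option actually computes, here is a minimal, illustrative sketch of the CFG-Zero-Star guidance update added in this commit (the function name and arguments below are placeholders, not the exact code used in the repository):

```python
import torch

def cfg_zero_star_guidance(noise_cond, noise_uncond, guide_scale, step, cfg_zero_step=-1):
    """Illustrative CFG-Zero-Star update (not the repository's exact code).

    The unconditional branch is rescaled by the least-squares projection of the
    conditional prediction onto it; the first `cfg_zero_step + 1` denoising steps
    can be "zero-initialized" by returning an all-zero prediction.
    """
    batch_size = noise_cond.shape[0]
    positive_flat = noise_cond.reshape(batch_size, -1)
    negative_flat = noise_uncond.reshape(batch_size, -1)
    # alpha = <cond, uncond> / (||uncond||^2 + eps), one scalar per sample
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    alpha = (dot_product / squared_norm).view(batch_size, *([1] * (noise_cond.dim() - 1)))
    if step <= cfg_zero_step:
        return noise_cond * 0.0  # zero-init: skip guidance entirely on early steps
    return noise_uncond * alpha + guide_scale * (noise_cond - noise_uncond * alpha)
```

When the CFG Star switch is off, the update falls back to the usual `noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)`, as the diff below shows.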
gradio_server.py (101 changed lines)
@@ -23,7 +23,7 @@ import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.3.3"
target_mmgp_version = "3.3.4"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@@ -300,6 +300,7 @@ if not Path(server_config_filename).is_file():
"metadata_type": "metadata",
"default_ui": "t2v",
"boost" : 1,
"clear_file_list" : 0,
"vae_config": 0,
"profile" : profile_type.LowRAM_LowVRAM }
@@ -382,7 +383,6 @@ if len(args.vae_config) > 0:

reload_needed = False
default_ui = server_config.get("default_ui", "t2v")
metadata = server_config.get("metadata_type", "metadata")
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
use_image2video = default_ui != "t2v"
if args.t2v:
@@ -741,7 +741,8 @@ def apply_changes( state,
vae_config_choice,
metadata_choice,
default_ui_choice ="t2v",
boost_choice = 1
boost_choice = 1,
clear_file_list = 0,
):
if args.lock_config:
return
@@ -760,6 +761,7 @@ def apply_changes( state,
"metadata_choice": metadata_choice,
"default_ui" : default_ui_choice,
"boost" : boost_choice,
"clear_file_list" : clear_file_list
}

if Path(server_config_filename).is_file():
@@ -792,7 +794,7 @@ def apply_changes( state,
text_encoder_filename = server_config["text_encoder_filename"]
vae_config = server_config["vae_config"]
boost = server_config["boost"]
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
pass
else:
reload_needed = True
@@ -849,13 +851,17 @@ def refresh_gallery(state, txt):
if len(prompt) == 0:
return file_list, gr.Text(visible= False, value="")
else:
choice = 0
if "in_progress" in state:
choice = state.get("selected",0)

prompts_max = state.get("prompts_max",0)
prompt_no = state.get("prompt_no",0)
if prompts_max >1 :
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
else:
label = f"Current Prompt"
return file_list, gr.Text(visible= True, value=prompt, label=label)
return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label)


def finalize_gallery(state):
@@ -863,6 +869,8 @@ def finalize_gallery(state):
if "in_progress" in state:
del state["in_progress"]
choice = state.get("selected",0)
# file_list = state.get("file_list", [])


state["extra_orders"] = 0
time.sleep(0.2)
@@ -930,6 +938,7 @@ def generate_video(
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_prompt_type,
image_to_continue,
image_to_end,
video_to_continue,
@@ -938,7 +947,9 @@ def generate_video(
slg_switch,
slg_layers,
slg_start,
slg_end,
slg_end,
cfg_star_switch,
cfg_zero_step,
state,
image2video,
progress=gr.Progress() #track_tqdm= True
@@ -1031,6 +1042,8 @@ def generate_video(
if len(prompts) ==0:
return
if image2video:
if image_prompt_type == 0:
image_to_end = None
if image_to_continue is not None:
if isinstance(image_to_continue, list):
image_to_continue = [ tup[0] for tup in image_to_continue ]
@@ -1135,7 +1148,6 @@ def generate_video(
if "abort" in state:
del state["abort"]
state["in_progress"] = True
state["selected"] = 0

enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
# VAE Tiling
@@ -1172,8 +1184,21 @@ def generate_video(
if seed == None or seed <0:
seed = random.randint(0, 999999999)

file_list = []
clear_file_list = server_config.get("clear_file_list", 0)
file_list = state.get("file_list", [])
if clear_file_list > 0:
file_list_current_size = len(file_list)
keep_file_from = max(file_list_current_size - clear_file_list, 0)
files_removed = keep_file_from
choice = state.get("selected",0)
choice = max(choice- files_removed, 0)
file_list = file_list[ keep_file_from: ]
else:
file_list = []
choice = 0
state["selected"] = choice
state["file_list"] = file_list

global save_path
os.makedirs(save_path, exist_ok=True)
video_no = 0
@@ -1240,6 +1265,8 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
)

else:
@@ -1260,6 +1287,8 @@ def generate_video(
slg_layers = slg_layers,
slg_start = slg_start/100,
slg_end = slg_end/100,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
)
except Exception as e:
gen_in_progress = False
@@ -1326,7 +1355,7 @@ def generate_video(
value_range=(-1, 1))

configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)

metadata_choice = server_config.get("metadata_choice","metadata")
if metadata_choice == "json":
@@ -1642,7 +1671,7 @@ def switch_advanced(state, new_advanced, lset_name):


def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):

loras = state["loras"]
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@@ -1666,7 +1695,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
"slg_switch": slg_switch,
"slg_layers": slg_layers,
"slg_start_perc": slg_start_perc,
"slg_end_perc": slg_end_perc
"slg_end_perc": slg_end_perc,
"cfg_star_switch": cfg_star_switch,
"cfg_zero_step": cfg_zero_step
}

if i2v:
@@ -1678,13 +1709,13 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
return ui_settings
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):

if state.get("validate_success",0) != 1:
return

ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)

defaults_filename = get_settings_file_name(use_image2video)
@@ -1955,7 +1986,7 @@ def generate_video_tab(image2video=False):
label="RIFLEx positional embedding to generate long video"
)
with gr.Row():
gr.Markdown("<B>Experimental: Skip Layer guidance,should improve video quality</B>")
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
with gr.Row():
slg_switch = gr.Dropdown(
choices=[
@@ -1979,6 +2010,23 @@ def generate_video_tab(image2video=False):
with gr.Row():
slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")

with gr.Row():
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
with gr.Row():
cfg_star_switch = gr.Dropdown(
choices=[
("OFF", 0),
("ON", 1),
],
value=ui_defaults.get("cfg_star_switch",0),
visible=True,
scale = 1,
label="CFG Star"
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")

with gr.Row():
save_settings_btn = gr.Button("Set Settings as Default")
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
@@ -1997,7 +2045,7 @@ def generate_video_tab(image2video=False):
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
slg_start_perc, slg_end_perc ], outputs = [])
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2035,6 +2083,7 @@ def generate_video_tab(image2video=False):
tea_cache_start_step_perc,
loras_choices,
loras_mult_choices,
image_prompt_type_radio,
image_to_continue,
image_to_end,
video_to_continue,
@@ -2044,6 +2093,8 @@ def generate_video_tab(image2video=False):
slg_layers,
slg_start_perc,
slg_end_perc,
cfg_star_switch,
cfg_zero_step,
state,
gr.State(image2video)
],
@@ -2175,9 +2226,24 @@ def generate_configuration_tab():
("Add metadata to video", "metadata"),
("Neither", "none")
],
value=metadata,
value=server_config.get("metadata_type", "metadata"),
label="Metadata Handling"
)

clear_file_list_choice = gr.Dropdown(
choices=[
("None", 0),
("Keep the last video", 1),
("Keep the last 5 videos", 5),
("Keep the last 10 videos", 10),
("Keep the last 20 videos", 20),
("Keep the last 30 videos", 30),
],
value=server_config.get("clear_file_list", 0),
label="Keep Previously Generated Videos when starting a Generation Batch"
)
msg = gr.Markdown()
apply_btn = gr.Button("Apply Changes")
apply_btn.click(
@@ -2195,6 +2261,7 @@ def generate_configuration_tab():
metadata_choice,
default_ui_choice,
boost_choice,
clear_file_list_choice,
],
outputs= msg
)
@@ -2262,7 +2329,7 @@ def create_demo():
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")

with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@@ -16,6 +16,6 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.3.3
mmgp==3.3.4
peft==0.14.0
mutagen
@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
|
||||
|
||||
from PIL import Image
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
def optimized_scale(positive_flat, negative_flat):
|
||||
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
|
||||
c = b1.shape[-1]
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
#norms
|
||||
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||
|
||||
#normalize
|
||||
b1_normalized = b1 / b1_norms
|
||||
b2_normalized = b2 / b2_norms
|
||||
|
||||
#zero when norms are zero
|
||||
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
||||
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
||||
|
||||
#slerp
|
||||
dot = (b1_normalized*b2_normalized).sum(1)
|
||||
omega = torch.acos(dot)
|
||||
so = torch.sin(omega)
|
||||
|
||||
#technically not mathematically correct, but more pleasing?
|
||||
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||
|
||||
#edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
orig_shape = tuple(samples.shape)
|
||||
if len(orig_shape) > 4:
|
||||
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
||||
samples = samples.movedim(2, 1)
|
||||
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
||||
if crop == "center":
|
||||
old_width = samples.shape[-1]
|
||||
old_height = samples.shape[-2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
||||
else:
|
||||
s = samples
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
out = bislerp(s, width, height)
|
||||
elif upscale_method == "lanczos":
|
||||
out = lanczos(s, width, height)
|
||||
else:
|
||||
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
if len(orig_shape) == 4:
|
||||
return out
|
||||
|
||||
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
|
||||
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm

return st_star

class WanI2V:

@@ -227,6 +167,8 @@ class WanI2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@@ -375,7 +317,7 @@ class WanI2V:

# sample videos
latent = noise

batch_size = latent.shape[0]
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)

arg_c = {
@@ -456,8 +398,23 @@ class WanI2V:
del latent_model_input
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)

alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)

if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0.
else:
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)

del noise_pred_uncond

latent = latent.to(
@ -70,30 +70,31 @@ def sageattn_wrapper(
|
||||
|
||||
return o
|
||||
|
||||
# # try:
|
||||
# try:
|
||||
# if True:
|
||||
# from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
|
||||
# @torch.compiler.disable()
|
||||
# def sageattn_window_wrapper(
|
||||
# qkv_list,
|
||||
# attention_length,
|
||||
# window
|
||||
# ):
|
||||
# q,k, v = qkv_list
|
||||
# padding_length = q.shape[0] -attention_length
|
||||
# q = q[:attention_length, :, : ].unsqueeze(0)
|
||||
# k = k[:attention_length, :, : ].unsqueeze(0)
|
||||
# v = v[:attention_length, :, : ].unsqueeze(0)
|
||||
# o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
|
||||
# del q, k ,v
|
||||
# qkv_list.clear()
|
||||
# from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
|
||||
# @torch.compiler.disable()
|
||||
# def sageattn_window_wrapper(
|
||||
# qkv_list,
|
||||
# attention_length,
|
||||
# window
|
||||
# ):
|
||||
# q,k, v = qkv_list
|
||||
# padding_length = q.shape[0] -attention_length
|
||||
# q = q[:attention_length, :, : ].unsqueeze(0)
|
||||
# k = k[:attention_length, :, : ].unsqueeze(0)
|
||||
# v = v[:attention_length, :, : ].unsqueeze(0)
|
||||
# qkvl_list = [q, k , v]
|
||||
# del q, k ,v
|
||||
# o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
|
||||
# qkv_list.clear()
|
||||
|
||||
# if padding_length > 0:
|
||||
# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
||||
# if padding_length > 0:
|
||||
# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
||||
|
||||
# return o
|
||||
# # except ImportError:
|
||||
# # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
|
||||
# return o
|
||||
# except ImportError:
|
||||
# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
|
||||
|
||||
@torch.compiler.disable()
|
||||
def sdpa_wrapper(
|
||||
@ -253,17 +254,19 @@ def pay_attention(
|
||||
# nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
|
||||
|
||||
# window = 0
|
||||
# start_window_step = int(max_steps * 0.4)
|
||||
# start_window_step = int(max_steps * 0.3)
|
||||
# start_layer = 10
|
||||
# if (layer < start_layer ) or current_step <start_window_step:
|
||||
# end_layer = 30
|
||||
# if (layer < start_layer or layer > end_layer ) or current_step <start_window_step:
|
||||
# window = 0
|
||||
# else:
|
||||
# coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
|
||||
# # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
|
||||
# coef = 0.3
|
||||
# print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
|
||||
# window = math.ceil(coef* nb_latents)
|
||||
|
||||
# invert_spaces = (layer + current_step) % 2 == 0 and window > 0
|
||||
|
||||
# invert_spaces = False
|
||||
# def flip(q):
|
||||
# q = q.reshape(*embed_sizes, *q.shape[-2:])
|
||||
# q = q.transpose(0,2)
|
||||
|
||||
@@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
self.init_weights()


# self.freqs = torch.cat([
# rope_params(1024, d - 4 * (d // 6)), #44
# rope_params(1024, 2 * (d // 6)), #42
# rope_params(1024, 2 * (d // 6)) #42
# ],dim=1)


def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
dim = self.dim
num_heads = self.num_heads
d = dim // num_heads
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0

c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
c2, s2 = rope_params(1024, 2 * (d // 6)) #42
c3, s3 = rope_params(1024, 2 * (d // 6)) #42

return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))

def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
@@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(

if pv_accum_dtype == "fp32":
if smooth_v:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
else:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)

o = o[..., :head_dim_og]
@@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed


def optimized_scale(positive_flat, negative_flat):

# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm

return st_star
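For context, `optimized_scale` returns the least-squares projection coefficient that CFG-Zero-Star uses to rescale the unconditional prediction; in the notation of the code comment above,

$$ s^{*} \;=\; \arg\min_{s}\ \lVert v_{\text{cond}} - s\, v_{\text{uncond}} \rVert^{2} \;=\; \frac{v_{\text{cond}}^{\top} v_{\text{uncond}}}{\lVert v_{\text{uncond}} \rVert^{2} + \varepsilon}, $$

computed per sample over the flattened predictions, with a small $\varepsilon$ (the `1e-8` above) to avoid division by zero.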

class WanT2V:

def __init__(
@@ -136,6 +150,8 @@ class WanT2V:
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from text prompt using diffusion process.
@@ -240,7 +256,7 @@ class WanT2V:

# sample videos
latents = noise

batch_size =len(latents)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@@ -249,7 +265,6 @@ class WanT2V:
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}

if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
@@ -280,8 +295,23 @@ class WanT2V:
return None

del latent_model_input
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)

# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)

alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)

if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0.
else:
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond

temp_x0 = sample_scheduler.step(