Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-05 14:33:15 +00:00)

Refactored Loras

Commit d233dd7ed9, parent 48b777ae6c
@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of video foundation models
 ## 🔥 Latest News!!
+* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of Loras, for all the Lora collectors out there. You will need to refresh the requirements: *pip install -r requirements.txt*
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Lora support, accelerated Lora loading. You will need to refresh the requirements: *pip install -r requirements.txt*
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), a 10% speed boost with no quality loss, and improved Lora presets (they can now include prompts and comments to guide the user)
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; it is now really 20% faster when activated
@@ -157,6 +158,11 @@ python gradio_server.py --attention sdpa
 Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be loaded automatically. You will then be able to activate / deactivate any of them when running the application, by selecting them in the area below "Activated Loras".
+
+If you want to manage the Loras for the 1.3B model and the 14B model in different areas (they are not compatible), just create the following subfolders:
+- loras/1.3B
+- loras/14B
+
 For each activated Lora, you may specify a *multiplier*, a float that corresponds to its weight (the default is 1.0). The multipliers for each Lora should be separated by a space character or a carriage return. For instance:\
 *1.2 0.8* means that the first Lora will have a 1.2 multiplier and the second one 0.8.
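A minimal sketch of how such a multiplier string can be parsed (illustrative only; the helper name is made up, but the padding-to-1.0 behavior mirrors the handling in gradio_server.py below):

```python
# Hypothetical helper: parse a multiplier string such as "1.2 0.8" and pad
# missing entries with the 1.0 default, as the application does.
def parse_multipliers(mult_str: str, num_loras: int) -> list[float]:
    parts = mult_str.replace("\r", "\n").split()  # split on spaces or newlines
    mults = [float(p) for p in parts]
    mults += [1.0] * (num_loras - len(mults))  # unspecified Loras default to 1.0
    return mults[:num_loras]

assert parse_multipliers("1.2 0.8", 3) == [1.2, 0.8, 1.0]
```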
gradio_server.py (332 changed lines)
@@ -170,6 +170,18 @@ def _parse_args():
     return args

+def get_lora_dir(root_lora_dir):
+    if not use_image2video:
+        if "1.3B" in transformer_filename_t2v:
+            lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
+            if os.path.isdir(lora_dir_1_3B):
+                return lora_dir_1_3B
+        else:
+            lora_dir_14B = os.path.join(root_lora_dir, "14B")
+            if os.path.isdir(lora_dir_14B):
+                return lora_dir_14B
+    return root_lora_dir
+
 attention_modes_supported = get_attention_modes()

 args = _parse_args()
@@ -238,9 +250,10 @@ if args.i2v:
 lora_dir = args.lora_dir
 if use_image2video and len(lora_dir) == 0:
-    lora_dir = args.lora_dir_i2v
+    root_lora_dir = args.lora_dir_i2v
 if len(lora_dir) == 0:
-    lora_dir = "loras_i2v" if use_image2video else "loras"
+    root_lora_dir = "loras_i2v" if use_image2video else "loras"
+lora_dir = get_lora_dir(root_lora_dir)
 lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 # if args.fast : #or args.fastest
@@ -295,35 +308,42 @@ def sanitize_file_name(file_name, rep =""):
     return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)

 def extract_preset(lset_name, loras):
+    loras_choices = []
+    loras_choices_files = []
+    loras_mult_choices = ""
+    prompt = ""
+    full_prompt = ""
     lset_name = sanitize_file_name(lset_name)
     if not lset_name.endswith(".lset"):
         lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
     else:
         lset_name_filename = os.path.join(lora_dir, lset_name)
+    error = ""
     if not os.path.isfile(lset_name_filename):
-        raise gr.Error(f"Preset '{lset_name}' not found ")
+        error = f"Preset '{lset_name}' not found "
+    else:
+        missing_loras = []
         with open(lset_name_filename, "r", encoding="utf-8") as reader:
             text = reader.read()
         lset = json.loads(text)

         loras_choices_files = lset["loras"]
-        loras_choices = []
-        missing_loras = []
-        for lora_file in loras_choices_files:
-            loras_choice_no = loras.index(os.path.join(lora_dir, lora_file))
-            if loras_choice_no < 0:
-                missing_loras.append(lora_file)
-            else:
-                loras_choices.append(str(loras_choice_no))
+        for lora_file in loras_choices_files:
+            choice = os.path.join(lora_dir, lora_file)
+            if choice not in loras:
+                missing_loras.append(lora_file)
+            else:
+                loras_choice_no = loras.index(choice)
+                loras_choices.append(str(loras_choice_no))

         if len(missing_loras) > 0:
-            raise gr.Error(f"Unable to apply Lora preset '{lset_name}' because the following Lora files are missing: {missing_loras}")
+            error = f"Unable to apply Lora preset '{lset_name}' because the following Lora files are missing or invalid: {missing_loras}"

         loras_mult_choices = lset["loras_mult"]
         prompt = lset.get("prompt", "")
-    return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
+        full_prompt = lset.get("full_prompt", False)
+    return loras_choices, loras_mult_choices, prompt, full_prompt, error
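Since extract_preset now returns its error as a string instead of raising, callers are expected to check it themselves; a minimal usage sketch (the preset name is hypothetical):

```python
# Sketch: the refactored extract_preset reports problems via the returned
# `error` string rather than raising gr.Error inside the helper.
loras_choices, loras_mult, prompt, full_prompt, error = extract_preset("my_preset", loras)
if len(error) > 0:
    print(error)  # the UI surfaces this via gr.Info(error) or print(error[:200])
```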

 def get_default_prompt(i2v):
     if i2v:

@@ -332,7 +352,7 @@ def get_default_prompt(i2v):
     return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."

-def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
+def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     loras = []
     loras_names = []
     default_loras_choices = []
@@ -341,32 +361,35 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     default_lora_preset = ""
     default_prompt = ""

-    if use_image2video or not "1.3B" in transformer_filename_t2v:
-        from pathlib import Path
+    from pathlib import Path

     if lora_dir != None:
         if not os.path.isdir(lora_dir):
             raise Exception("--lora-dir should be a path to a directory that contains Loras")

     if lora_dir != None:
         import glob
         dir_loras = glob.glob(os.path.join(lora_dir, "*.sft")) + glob.glob(os.path.join(lora_dir, "*.safetensors"))
         dir_loras.sort()
         loras += [element for element in dir_loras if element not in loras]

         dir_presets = glob.glob(os.path.join(lora_dir, "*.lset"))
         dir_presets.sort()
         loras_presets = [Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]

+    loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only=True, split_linear_modules_map=split_linear_modules_map) #lora_multiplier,
+
     if len(loras) > 0:
         loras_names = [Path(lora).stem for lora in loras]
-        offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map=split_linear_modules_map) #lora_multiplier,

     if len(lora_preselected_preset) > 0:
         if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
             raise Exception(f"Unknown preset '{lora_preselected_preset}'")
         default_lora_preset = lora_preselected_preset
-        default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
+        default_loras_choices, default_loras_multis_str, default_prompt, _, error = extract_preset(default_lora_preset, loras)
+        if len(error) > 0:
+            print(error[:200])
     if len(default_prompt) == 0:
         default_prompt = get_default_prompt(use_image2video)
     return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -450,8 +473,8 @@ def load_models(i2v, lora_dir, lora_preselected_preset):
         kwargs["budgets"] = { "*" : "70%" }

-    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
-    offloadobj = offload.profile(pipe, profile_no=profile, compile=compile, quantizeTransformer=quantizeTransformer, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no=profile, compile=compile, quantizeTransformer=quantizeTransformer, loras="transformer", **kwargs)
+    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe["transformer"], lora_dir, lora_preselected_preset, None)

     return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -542,7 +565,7 @@ def apply_changes( state,
     state["config_new"] = server_config
     state["config_old"] = old_server_config

-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -560,7 +583,7 @@ def apply_changes( state,
         offloadobj.release()
         offloadobj = None
     yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"

+    lora_dir = get_lora_dir(root_lora_dir)
     wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset)
@@ -590,7 +613,11 @@ def update_defaults(state, num_inference_steps, flow_shift):
         num_inference_steps, flow_shift = get_default_flow(trans_file)

     header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"])
-    return num_inference_steps, flow_shift, header
+    new_loras_choices = [(loras_name, str(i)) for i, loras_name in enumerate(loras_names)]
+    lset_choices = [(preset, preset) for preset in loras_presets]
+    lset_choices.append((new_preset_msg, ""))
+
+    return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value=""), gr.Dropdown(choices=new_loras_choices, value=[])

 from moviepy.editor import ImageSequenceClip
@@ -603,23 +630,32 @@ def save_video(final_frames, output_path, fps=24):
     ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose=False, logger=None)

 def build_callback(state, pipe, progress, status, num_inference_steps):
-    def callback(step_idx, latents):
-        step_idx += 1
-        if state.get("abort", False):
-            # pipe._interrupt = True
-            status_msg = status + " - Aborting"
-        elif step_idx == num_inference_steps:
-            status_msg = status + " - VAE Decoding"
-        else:
-            status_msg = status + " - Denoising"
-
-        progress( (step_idx, num_inference_steps), status_msg, num_inference_steps)
+    def callback(step_idx, latents, read_state = False):
+        status = state["progress_status"]
+        if read_state:
+            phase, step_idx = state["progress_phase"]
+        else:
+            step_idx += 1
+            if state.get("abort", False):
+                # pipe._interrupt = True
+                phase = " - Aborting"
+            elif step_idx == num_inference_steps:
+                phase = " - VAE Decoding"
+            else:
+                phase = " - Denoising"
+            state["progress_phase"] = (phase, step_idx)
+        status_msg = status + phase
+        if step_idx >= 0:
+            progress( (step_idx, num_inference_steps), status_msg, num_inference_steps)
+        else:
+            progress(0, status_msg)

     return callback
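The new read_state flag lets code outside the denoising loop replay the last recorded progress phase (this is what the transformer's block loop does via offload.shared_state["callback"](-1, -1, True) later in this diff). A self-contained sketch of the mechanism:

```python
# Self-contained sketch: the callback normally advances the step counter, but
# with read_state=True it replays the phase/step recorded in `state` (used to
# refresh the UI after "One More Please!" changes the totals).
state = {"progress_status": "Video 1/2", "progress_phase": (" - Denoising", 7)}

def callback(step_idx, latents, read_state=False):
    if read_state:
        phase, step_idx = state["progress_phase"]
    else:
        step_idx += 1
        phase = " - Denoising"
        state["progress_phase"] = (phase, step_idx)
    print(f"{state['progress_status']}{phase} (step {step_idx})")

callback(7, None)         # advances to step 8 and records the phase
callback(-1, None, True)  # replays the recorded phase without advancing
```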

 def abort_generation(state):
     if "in_progress" in state:
         state["abort"] = True
+        state["extra_orders"] = 0
         wan_model._interrupt = True
         return gr.Button(interactive=False)
     else:
@@ -634,11 +670,12 @@ def finalize_gallery(state):
     if "in_progress" in state:
         del state["in_progress"]
         choice = state.get("selected", 0)
+        state["extra_orders"] = 0
         time.sleep(0.2)
         global gen_in_progress
         gen_in_progress = False
-        return gr.Gallery(selected_index=choice), gr.Button(interactive=True)
+        return gr.Gallery(selected_index=choice), gr.Button(interactive=True), gr.Button(visible=True), gr.Checkbox(visible=False)

 def select_video(state, event_data: gr.EventData):
     data = event_data._data
@@ -656,6 +693,32 @@ def expand_slist(slist, num_inference_steps):
     return new_slist

+def one_more_video(state):
+    extra_orders = state.get("extra_orders", 0)
+    extra_orders += 1
+    state["extra_orders"] = extra_orders
+    prompts_max = state["prompts_max"]
+    prompt_no = state["prompt_no"]
+    video_no = state["video_no"]
+    total_video = state["total_video"]
+    # total_video += (prompts_max - prompt_no)
+    total_video += 1
+    total_generation = state["total_generation"] + extra_orders
+    state["total_video"] = total_video
+
+    state["progress_status"] = f"Video {video_no}/{total_video}"
+    offload.shared_state["refresh"] = 1
+    # if (prompts_max - prompt_no) > 1:
+    #     gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
+    # else:
+    gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
+
+    return state
+
+def prepare_generate_video():
+    return gr.Button(visible=False), gr.Checkbox(visible=True)
+
 def generate_video(
     prompt,
     negative_prompt,
@@ -733,6 +796,7 @@ def generate_video(
     else:
         VAE_tile_size = 128

+    trans = wan_model.model

     global gen_in_progress
     gen_in_progress = True
@@ -740,7 +804,7 @@ def generate_video(
     if len(prompt) == 0:
         return
     prompts = prompt.replace("\r", "").split("\n")
-    prompts = [prompt for prompt in prompts if len(prompt) > 0 and not prompt.startswith("#")]
+    prompts = [prompt.strip() for prompt in prompts if len(prompt.strip()) > 0 and not prompt.startswith("#")]
     if len(prompts) == 0:
         return
     if use_image2video:
@@ -808,9 +872,13 @@ def generate_video(
             list_mult_choices_nums.append(float(mult))
     if len(list_mult_choices_nums) < len(loras_choices):
         list_mult_choices_nums += [1.0] * (len(loras_choices) - len(list_mult_choices_nums))
-    offload.activate_loras(wan_model.model, loras_choices, list_mult_choices_nums)
+    loras_selected = [lora for i, lora in enumerate(loras) if str(i) in loras_choices]
+    pinnedLora = False  # profile != 5
+    offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map=None)
+    errors = trans._loras_errors
+    if len(errors) > 0:
+        error_files = [msg for _, msg in errors]
+        raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
     seed = None if seed == -1 else seed
     # negative_prompt = "" # not applicable in the inference
@@ -825,7 +893,6 @@ def generate_video(

     joint_pass = boost == 1
     # TeaCache
-    trans = wan_model.model
     trans.enable_teacache = tea_cache > 0
     if trans.enable_teacache:
         if use_image2video:
@@ -857,11 +924,23 @@ def generate_video(
     os.makedirs(save_path, exist_ok=True)
     video_no = 0
     total_video = repeat_generation * len(prompts)
+    state["total_video"] = total_video
+    extra_generation = 0
     abort = False
     start_time = time.time()
-    for prompt in prompts:
-        for _ in range(repeat_generation):
-            if abort:
+    state["prompts_max"] = len(prompts)
+    for no, prompt in enumerate(prompts):
+        repeat_no = 0
+        state["prompt_no"] = no
+        extra_generation = 0
+        while True:
+            extra_orders = state.get("extra_orders", 0)
+            state["extra_orders"] = 0
+            extra_generation += extra_orders
+            state["total_generation"] = repeat_generation + extra_generation
+            # total_video += (len(prompts) - no) * extra_orders
+            total_video += extra_orders
+            if abort or repeat_no >= (repeat_generation + extra_generation):
                 break

             if trans.enable_teacache:
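The rewritten loop above is what makes the "One More Please !" button work: each prompt's repeat loop is now a while-loop whose bound can grow while it runs, because the UI bumps state["extra_orders"] and the worker folds that into the bound on every iteration. A self-contained sketch of the accounting:

```python
# Sketch of the extra-orders accounting: one click on "One More Please!"
# extends the requested repeat count mid-run.
state = {"extra_orders": 1}  # simulate one click before the loop starts
repeat_generation = 2
extra_generation = 0
repeat_no = 0
videos = []
while True:
    extra_generation += state.get("extra_orders", 0)
    state["extra_orders"] = 0  # consume the pending orders
    if repeat_no >= repeat_generation + extra_generation:
        break
    repeat_no += 1
    videos.append(repeat_no)
assert videos == [1, 2, 3]  # one extra video beyond the requested two
```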
@@ -875,9 +954,12 @@ def generate_video(

             video_no += 1
             status = f"Video {video_no}/{total_video}"
+            state["video_no"] = video_no
+            state["progress_status"] = status
+            state["progress_phase"] = (" - Encoding Prompt", -1)
             progress(0, desc=status + " - Encoding Prompt")

             callback = build_callback(state, trans, progress, status, num_inference_steps)
+            offload.shared_state["callback"] = callback

             gc.collect()
@@ -887,7 +969,7 @@ def generate_video(
             if use_image2video:
                 samples = wan_model.generate(
                     prompt,
-                    image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
+                    image_to_continue[no].convert('RGB'),
                     frame_num=(video_length // 4) * 4 + 1,
                     max_area=MAX_AREA_CONFIGS[resolution],
                     shift=flow_shift,
@@ -923,6 +1005,7 @@ def generate_video(
                 if temp_filename != None and os.path.isfile(temp_filename):
                     os.remove(temp_filename)
                 offload.last_offload_obj.unload_all()
+                offload.unload_loras_from_model(trans)
                 # if compile:
                 #     cache_size = torch._dynamo.config.cache_size_limit
                 #     torch.compiler.reset()
@@ -967,9 +1050,9 @@ def generate_video(

             time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
             if os.name == 'nt':
-                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50])}.mp4"
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
             else:
-                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100])}.mp4"
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
             video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
             cache_video(
                 tensor=sample[None],
@@ -987,10 +1070,13 @@ def generate_video(
             end_time = time.time()
             yield f"Total Generation Time: {end_time-start_time:.1f}s"
             seed += 1
+            repeat_no += 1

     if temp_filename != None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
     gen_in_progress = False
+    offload.unload_loras_from_model(trans)

 new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
@@ -998,19 +1084,19 @@ new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
 def validate_delete_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info(f"Choose a Preset to delete")
-        return gr.Button(visible=True), gr.Checkbox(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False)
+        return gr.Button(visible=True), gr.Checkbox(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False)
     else:
-        return gr.Button(visible=False), gr.Checkbox(visible=False), gr.Button(visible=False), gr.Button(visible=True), gr.Button(visible=True)
+        return gr.Button(visible=False), gr.Checkbox(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=True), gr.Button(visible=True)

 def validate_save_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info("Please enter a name for the preset")
-        return gr.Button(visible=True), gr.Checkbox(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)
+        return gr.Button(visible=True), gr.Checkbox(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)
     else:
-        return gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=True), gr.Button(visible=True), gr.Checkbox(visible=True)
+        return gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=True), gr.Button(visible=True), gr.Checkbox(visible=True)

 def cancel_lset():
-    return gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)
+    return gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)

 def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets
@@ -1047,7 +1133,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     lset_choices = [(preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))

-    return gr.Dropdown(choices=lset_choices, value=lset_name), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)
+    return gr.Dropdown(choices=lset_choices, value=lset_name), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Button(visible=False), gr.Checkbox(visible=False)

 def delete_lset(lset_name):
     global loras_presets
@@ -1065,23 +1151,57 @@ def delete_lset(lset_name):

     lset_choices = [(preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value=lset_choices[pos][1]), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Checkbox(visible=False)
+    return gr.Dropdown(choices=lset_choices, value=lset_choices[pos][1]), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=True), gr.Button(visible=False), gr.Checkbox(visible=False)
+
+def refresh_lora_list(lset_name, loras_choices):
+    global loras, loras_names, loras_presets
+    prev_lora_names_selected = [loras_names[int(i)] for i in loras_choices]
+
+    loras, loras_names, _, _, _, _, loras_presets = setup_loras(wan_model.model, lora_dir, lora_preselected_preset, None)
+    gc.collect()
+    new_loras_choices = [(loras_name, str(i)) for i, loras_name in enumerate(loras_names)]
+    new_loras_dict = {loras_name: str(i) for i, loras_name in enumerate(loras_names)}
+    lora_names_selected = []
+    for lora in prev_lora_names_selected:
+        lora_id = new_loras_dict.get(lora, None)
+        if lora_id != None:
+            lora_names_selected.append(lora_id)
+
+    lset_choices = [(preset, preset) for preset in loras_presets]
+    lset_choices.append((new_preset_msg, ""))
+    if lset_name in loras_presets:
+        pos = loras_presets.index(lset_name)
+    else:
+        pos = len(loras_presets)
+        lset_name = ""
+
+    errors = wan_model.model._loras_errors
+    if len(errors) > 0:
+        error_files = [path for path, _ in errors]
+        gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
+    else:
+        gr.Info("Lora List has been refreshed")
+
+    return gr.Dropdown(choices=lset_choices, value=lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value=lora_names_selected)

 def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info("Please choose a preset in the list or create one")
     else:
-        loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
-        if full_prompt:
-            prompt = preset_prompt
-        elif len(preset_prompt) > 0:
-            prompts = prompt.replace("\r", "").split("\n")
-            prompts = [prompt for prompt in prompts if len(prompt) > 0 and not prompt.startswith("#")]
-            prompt = "\n".join(prompts)
-            prompt = preset_prompt + '\n' + prompt
-        gr.Info(f"Lora Preset '{lset_name}' has been applied")
+        loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
+        if len(error) > 0:
+            gr.Info(error)
+        else:
+            if full_prompt:
+                prompt = preset_prompt
+            elif len(preset_prompt) > 0:
+                prompts = prompt.replace("\r", "").split("\n")
+                prompts = [prompt for prompt in prompts if len(prompt) > 0 and not prompt.startswith("#")]
+                prompt = "\n".join(prompts)
+                prompt = preset_prompt + '\n' + prompt
+            gr.Info(f"Lora Preset '{lset_name}' has been applied")

     return loras_choices, loras_mult_choices, prompt
@@ -1094,21 +1214,21 @@ def create_demo():
     state = gr.State({})

     if use_image2video:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
     else:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")

-    gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
+    gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP by <B>DeepBeepMeep</B>, a super fast and low VRAM Video Generator!</FONT>")

     if use_image2video and False:
         pass
     else:
-        gr.Markdown("The VRAM requirements will depend greatly on the resolution and the duration of the video; for instance, with 24 GB of VRAM (RTX 3090 / RTX 4090) the limits are as follows:")
+        gr.Markdown("The VRAM requirements will depend greatly on the resolution and the duration of the video, for instance:")
         gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s): 8 GB of VRAM")
         gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s): 5 GB of VRAM")
         gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
         gr.Markdown("It is not recommended to generate a video longer than 8s (128 frames) even if there is still some VRAM left, as some artifacts may appear")
-        gr.Markdown("Please note that if you turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
+        gr.Markdown("Please note that if you turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")

     # css = """<STYLE>
@@ -1302,6 +1422,7 @@ def create_demo():
                     # with gr.Column():
                     with gr.Row(height=17):
                         apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width=1)
+                        refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width=1)
                         # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
                         save_lset_prompt_drop = gr.Dropdown(
                             choices=[
@@ -1334,7 +1455,7 @@ def create_demo():
                 with gr.Row(visible=False) as advanced_row:
                     with gr.Column():
                         seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
-                        repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated Videos per prompt")
+                        repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
                         with gr.Row():
                             negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                         with gr.Row():
@@ -1377,22 +1498,25 @@ def create_demo():
                     label="Generated videos", show_label=False, elem_id="gallery"
                 , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive=False)
                 generate_btn = gr.Button("Generate")
+                onemore_btn = gr.Button("One More Please !", visible=False)
                 abort_btn = gr.Button("Abort")

-        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
-        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
-        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
+        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
+        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])

         apply_lset_btn.click(apply_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])

+        refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name, loras_choices], outputs=[lset_name, loras_choices])
+
         gen_status.change(refresh_gallery, inputs=[state], outputs=output)

         abort_btn.click(abort_generation, state, abort_btn)
         output.select(select_video, state, None)
+        onemore_btn.click(fn=one_more_video, inputs=[state], outputs=[state])
-        generate_btn.click(
+        generate_btn.click(fn=prepare_generate_video, inputs=[], outputs=[generate_btn, onemore_btn]).then(
             fn=generate_video,
             inputs=[
                 prompt,
@@ -1420,7 +1544,7 @@ def create_demo():
         ).then(
             finalize_gallery,
             [state],
-            [output, abort_btn]
+            [output, abort_btn, generate_btn, onemore_btn]
         )

         apply_btn.click(
@@ -1441,7 +1565,7 @@ def create_demo():
         ).then(
             update_defaults,
             [state, num_inference_steps, flow_shift],
-            [num_inference_steps, flow_shift, header]
+            [num_inference_steps, flow_shift, header, lset_name, loras_choices]
         )

     return demo
requirements.txt
@@ -16,5 +16,5 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.2.6
+mmgp==3.2.7
 peft==0.14.0
wan/image2video.py
@@ -331,7 +331,7 @@ class WanI2V:
             callback(-1, None)

         for i, t in enumerate(tqdm(timesteps)):
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
wan/modules/attention.py
@@ -60,6 +60,30 @@ try:
 except ImportError:
     sageattn = None

+# # try:
+# if True:
+#     from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
+#     @torch.compiler.disable()
+#     def sageattn_window_wrapper(
+#         qkv_list,
+#         attention_length,
+#         window
+#     ):
+#         q, k, v = qkv_list
+#         padding_length = q.shape[0] - attention_length
+#         q = q[:attention_length, :, :].unsqueeze(0)
+#         k = k[:attention_length, :, :].unsqueeze(0)
+#         v = v[:attention_length, :, :].unsqueeze(0)
+#         o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window=window).squeeze(0)
+#         del q, k, v
+#         qkv_list.clear()
+
+#         if padding_length > 0:
+#             o = torch.cat([o, torch.empty((padding_length, *o.shape[-2:]), dtype=o.dtype, device=o.device)], 0)
+
+#         return o
+# # except ImportError:
+# #     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
+
 @torch.compiler.disable()
 def sdpa_wrapper(
@@ -119,7 +143,8 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
-    force_attention= None
+    force_attention= None,
+    cross_attn= False
 ):
     """
     q: [B, Lq, Nq, C1].
@@ -194,9 +219,67 @@ def pay_attention(
             max_seqlen_kv=lk,
         ).unflatten(0, (b, lq))
     elif attn=="sage2":
-        qkv_list = [q, k, v]
-        del q, k, v
-        x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
+        import math
+        if cross_attn or True:
+            qkv_list = [q, k, v]
+            del q, k, v
+
+            x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
+        # else:
+        #     layer = offload.shared_state["layer"]
+        #     embed_sizes = offload.shared_state["embed_sizes"]
+        #     current_step = offload.shared_state["step_no"]
+        #     max_steps = offload.shared_state["max_steps"]
+
+        #     nb_latents = embed_sizes[0] * embed_sizes[1] * embed_sizes[2]
+
+        #     window = 0
+        #     start_window_step = int(max_steps * 0.4)
+        #     start_layer = 10
+        #     if (layer < start_layer) or current_step < start_window_step:
+        #         window = 0
+        #     else:
+        #         coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
+        #         print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
+        #         window = math.ceil(coef * nb_latents)
+
+        #     invert_spaces = (layer + current_step) % 2 == 0 and window > 0
+
+        #     def flip(q):
+        #         q = q.reshape(*embed_sizes, *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape(-1, *q.shape[-2:])
+        #         return q
+
+        #     def flop(q):
+        #         q = q.reshape(embed_sizes[2], embed_sizes[1], embed_sizes[0], *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape(-1, *q.shape[-2:])
+        #         return q
+
+        #     if invert_spaces:
+        #         q = flip(q)
+        #         k = flip(k)
+        #         v = flip(v)
+        #     qkv_list = [q, k, v]
+        #     del q, k, v
+
+        #     x = sageattn_window_wrapper(qkv_list, lq, window=window) #.unsqueeze(0)
+
+        #     if invert_spaces:
+        #         x = flop(x)
+        #     x = x.unsqueeze(0)
+
     elif attn=="sdpa":
         qkv_list = [q, k, v]
         del q, k, v
wan/modules/model.py
@@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 import numpy as np
 from typing import Union, Optional
+from mmgp import offload
 from .attention import pay_attention

 __all__ = ['WanModel']
@@ -302,7 +302,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         # compute attention
         qvl_list = [q, k, v]
         del q, k, v
-        x = pay_attention(qvl_list, k_lens=context_lens)
+        x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)

         # output
         x = x.flatten(2)
@@ -716,7 +716,9 @@ class WanModel(ModelMixin, ConfigMixin):
         pipeline = None,
         current_step = 0,
         context2 = None,
-        is_uncond=False
+        is_uncond=False,
+        max_steps = 0
     ):
         r"""
         Forward pass through the diffusion model
@@ -755,6 +757,12 @@ class WanModel(ModelMixin, ConfigMixin):
         # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])

         grid_sizes = [ list(u.shape[2:]) for u in x]
+        embed_sizes = grid_sizes[0]
+
+        offload.shared_state["embed_sizes"] = embed_sizes
+        offload.shared_state["step_no"] = current_step
+        offload.shared_state["max_steps"] = max_steps
+
         x = [u.flatten(2).transpose(1, 2) for u in x]
         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
@@ -843,7 +851,11 @@ class WanModel(ModelMixin, ConfigMixin):
             # context=context,
             context_lens=context_lens)

-        for block in self.blocks:
+        for l, block in enumerate(self.blocks):
+            offload.shared_state["layer"] = l
+            if "refresh" in offload.shared_state:
+                del offload.shared_state["refresh"]
+                offload.shared_state["callback"](-1, -1, True)
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
wan/text2video.py
@@ -243,6 +243,10 @@ class WanT2V:
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}

+        # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
@@ -250,7 +254,7 @@ class WanT2V:
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)

             # self.model.to(self.device)