Added support for more Lora formats

This commit is contained in:
DeepBeepMeep 2025-03-18 21:29:43 +01:00
parent c1b9b94143
commit b4afced9ce
3 changed files with 66 additions and 10 deletions

View File

@@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 17 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
* Mar 18 2025: 👋 Wan2.1GP v2.1: More Loras!: added support for 'Safetensors' and 'Replicate' Lora formats.\
You will need to refresh the requirements with a *pip install -r requirements.txt*
* Mar 17 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
- Clearer user interface
- Download 30 Loras in one click to try them all (expand the info section)
- Very easy to use Loras, as Lora presets can now include the subject (or other needed terms) of the Lora, so that you don't have to modify a prompt manually

View File

@@ -23,7 +23,12 @@ import asyncio
from wan.utils import prompt_parser
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.2.8"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
exit()
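
A minimal standalone sketch of this exact-version gate, for reference; it is not part of the commit, the helper name require_exact is illustrative, and "3.2.8" mirrors the pin in requirements.txt:

# Illustrative sketch (not from this commit): an exact-version gate like the one above.
from importlib.metadata import PackageNotFoundError, version

def require_exact(package: str, wanted: str) -> None:  # hypothetical helper name
    try:
        installed = version(package)
    except PackageNotFoundError:
        raise SystemExit(f"{package} is not installed; run 'pip install -r requirements.txt'")
    if installed != wanted:
        raise SystemExit(f"Incorrect version of {package} ({installed}), version {wanted} is needed")

require_exact("mmgp", "3.2.8")
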
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
@@ -77,8 +82,7 @@ def _parse_args():
    parser.add_argument(
        "--check-loras",
        type=str,
        default=0,
        action="store_true",
        help="Filter Loras that are not valid"
    )
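
For reference, a small sketch (not from this commit) of what switching from type=str/default=0 to action="store_true" does: the option becomes a boolean switch, parsed as True when present and False when absent (argparse maps --check-loras to args.check_loras):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--check-loras", action="store_true",
                    help="Filter Loras that are not valid")
print(parser.parse_args(["--check-loras"]).check_loras)  # True
print(parser.parse_args([]).check_loras)                 # False
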
@@ -346,6 +350,54 @@ if args.compile: #args.fastest or
    #attention_mode="xformers"
    # compile = "transformer"
def preprocess_loras(sd):
    # The conversion only applies to the image2video transformer's key layout
    if not use_image2video:
        return sd

    new_sd = {}
    first = next(iter(sd), None)
    if first is None:
        return sd
    if not first.startswith("lora_unet_"):
        return sd

    print("Converting Lora Safetensors format to Lora Diffusers format")
    alphas = {}
    repl_list = ["cross_attn", "self_attn", "ffn"]
    src_list = ["_" + k + "_" for k in repl_list]
    tgt_list = ["." + k + "." for k in repl_list]
    # Rename keys to the Diffusers convention, collecting alpha tensors separately
    for k, v in sd.items():
        k = k.replace("lora_unet_blocks_", "diffusion_model.blocks.")
        for s, t in zip(src_list, tgt_list):
            k = k.replace(s, t)
        k = k.replace("lora_up", "lora_B")
        k = k.replace("lora_down", "lora_A")
        if "alpha" in k:
            alphas[k] = v
        else:
            new_sd[k] = v

    # Turn each alpha into a scale (alpha / rank), reading the rank off the matching matrix
    new_alphas = {}
    for k, v in new_sd.items():
        if "lora_B" in k:
            dim = v.shape[1]
        elif "lora_A" in k:
            dim = v.shape[0]
        else:
            continue
        alpha_key = k[:-len("lora_X.weight")] + "alpha"
        if alpha_key in alphas:
            scale = alphas[alpha_key] / dim
            new_alphas[alpha_key] = scale
        else:
            print(f"Lora alpha '{alpha_key}' is missing")
    new_sd.update(new_alphas)
    return new_sd
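
A hypothetical walk-through (not part of the commit) of what the conversion above does to a single "lora_unet_" entry, assuming use_image2video is set and torch is installed; the key name and shapes are illustrative:

import torch

sd = {
    "lora_unet_blocks_0_cross_attn_q.lora_down.weight": torch.zeros(16, 5120),  # rank 16
    "lora_unet_blocks_0_cross_attn_q.lora_up.weight": torch.zeros(5120, 16),
    "lora_unet_blocks_0_cross_attn_q.alpha": torch.tensor(8.0),
}
new_sd = preprocess_loras(sd)
# Keys are renamed to the Diffusers convention and alpha becomes alpha / rank:
#   diffusion_model.blocks.0.cross_attn.q.lora_A.weight
#   diffusion_model.blocks.0.cross_attn.q.lora_B.weight
#   diffusion_model.blocks.0.cross_attn.q.alpha  -> tensor(0.5)
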
def download_models(transformer_filename, text_encoder_filename):
    def computeList(filename):
        pos = filename.rfind("/")
@@ -450,7 +502,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
    loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
    if check_loras:
        loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
        loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=preprocess_loras, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
    if len(loras) > 0:
        loras_names = [ Path(lora).stem for lora in loras ]
@@ -986,7 +1038,7 @@ def generate_video(
    list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
    loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
    pinnedLora = False # profile !=5
    offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
    offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
    errors = trans._loras_errors
    if len(errors) > 0:
        error_files = [msg for _ , msg in errors]
@@ -1534,9 +1586,9 @@ def create_demo():
    state_dict = {}
    if use_image2video:
        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
    else:
        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
    gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
@@ -1727,7 +1779,9 @@ def create_demo():
    advanced_prompt = advanced
    prompt_vars=[]
    if not advanced_prompt:
    if advanced_prompt:
        default_wizard_prompt, variables, values= None, None, None
    else:
        default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
        advanced_prompt = len(errors) > 0

View File

@@ -16,5 +16,5 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.2.7
mmgp==3.2.8
peft==0.14.0