mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 22:26:36 +00:00
Added support for more Loras format
This commit is contained in:
parent
c1b9b94143
commit
b4afced9ce
@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
||||
|
||||
|
||||
## 🔥 Latest News!!
|
||||
* Marc 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues:
|
||||
* Mar 18 2022: 👋 Wan2.1GP v2.1: More Loras !: added support for 'Safetensors' and 'Replicate' Lora formats.\
|
||||
You will need to refresh the requirements with a *pip install -r requirements.txt*
|
||||
* Mar 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues:
|
||||
- Clearer user interface
|
||||
- Download 30 Loras in one click to try them all (expand the info section)
|
||||
- Very to use Loras as now Lora presets can input the subject (or other need terms) of the Lora so that you dont have to modify manually a prompt
|
||||
|
||||
@ -23,7 +23,12 @@ import asyncio
|
||||
from wan.utils import prompt_parser
|
||||
PROMPT_VARS_MAX = 10
|
||||
|
||||
|
||||
target_mmgp_version = "3.2.8"
|
||||
from importlib.metadata import version
|
||||
mmgp_version = version("mmgp")
|
||||
if mmgp_version != target_mmgp_version:
|
||||
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
|
||||
exit()
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a video from a text prompt or image using Gradio")
|
||||
@ -77,8 +82,7 @@ def _parse_args():
|
||||
|
||||
parser.add_argument(
|
||||
"--check-loras",
|
||||
type=str,
|
||||
default=0,
|
||||
action="store_true",
|
||||
help="Filter Loras that are not valid"
|
||||
)
|
||||
|
||||
@ -346,6 +350,54 @@ if args.compile: #args.fastest or
|
||||
#attention_mode="xformers"
|
||||
# compile = "transformer"
|
||||
|
||||
def preprocess_loras(sd):
|
||||
if not use_image2video:
|
||||
return sd
|
||||
|
||||
new_sd = {}
|
||||
first = next(iter(sd), None)
|
||||
if first == None:
|
||||
return sd
|
||||
if not first.startswith("lora_unet_"):
|
||||
return sd
|
||||
print("Converting Lora Safetensors format to Lora Diffusers format")
|
||||
alphas = {}
|
||||
repl_list = ["cross_attn", "self_attn", "ffn"]
|
||||
src_list = ["_" + k + "_" for k in repl_list]
|
||||
tgt_list = ["." + k + "." for k in repl_list]
|
||||
|
||||
for k,v in sd.items():
|
||||
k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
|
||||
|
||||
for s,t in zip(src_list, tgt_list):
|
||||
k = k.replace(s,t)
|
||||
|
||||
k = k.replace("lora_up","lora_B")
|
||||
k = k.replace("lora_down","lora_A")
|
||||
|
||||
if "alpha" in k:
|
||||
alphas[k] = v
|
||||
else:
|
||||
new_sd[k] = v
|
||||
|
||||
new_alphas = {}
|
||||
for k,v in new_sd.items():
|
||||
if "lora_B" in k:
|
||||
dim = v.shape[1]
|
||||
elif "lora_A" in k:
|
||||
dim = v.shape[0]
|
||||
else:
|
||||
continue
|
||||
alpha_key = k[:-len("lora_X.weight")] +"alpha"
|
||||
if alpha_key in alphas:
|
||||
scale = alphas[alpha_key] / dim
|
||||
new_alphas[alpha_key] = scale
|
||||
else:
|
||||
print(f"Lora alpha'{alpha_key}' is missing")
|
||||
new_sd.update(new_alphas)
|
||||
return new_sd
|
||||
|
||||
|
||||
def download_models(transformer_filename, text_encoder_filename):
|
||||
def computeList(filename):
|
||||
pos = filename.rfind("/")
|
||||
@ -450,7 +502,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
|
||||
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
||||
|
||||
if check_loras:
|
||||
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
||||
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=preprocess_loras, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
||||
|
||||
if len(loras) > 0:
|
||||
loras_names = [ Path(lora).stem for lora in loras ]
|
||||
@ -986,7 +1038,7 @@ def generate_video(
|
||||
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
||||
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
||||
pinnedLora = False # profile !=5
|
||||
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
||||
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
||||
errors = trans._loras_errors
|
||||
if len(errors) > 0:
|
||||
error_files = [msg for _ , msg in errors]
|
||||
@ -1534,9 +1586,9 @@ def create_demo():
|
||||
state_dict = {}
|
||||
|
||||
if use_image2video:
|
||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||
else:
|
||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||
|
||||
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
||||
|
||||
@ -1727,7 +1779,9 @@ def create_demo():
|
||||
|
||||
advanced_prompt = advanced
|
||||
prompt_vars=[]
|
||||
if not advanced_prompt:
|
||||
if advanced_prompt:
|
||||
default_wizard_prompt, variables, values= None, None, None
|
||||
else:
|
||||
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
|
||||
advanced_prompt = len(errors) > 0
|
||||
|
||||
|
||||
@ -16,5 +16,5 @@ gradio>=5.0.0
|
||||
numpy>=1.23.5,<2
|
||||
einops
|
||||
moviepy==1.0.3
|
||||
mmgp==3.2.7
|
||||
mmgp==3.2.8
|
||||
peft==0.14.0
|
||||
Loading…
Reference in New Issue
Block a user