mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 06:29:14 +00:00
Added support for more Loras format
This commit is contained in:
parent
c1b9b94143
commit
b4afced9ce
@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
|||||||
|
|
||||||
|
|
||||||
## 🔥 Latest News!!
|
## 🔥 Latest News!!
|
||||||
* Marc 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues:
|
* Mar 18 2022: 👋 Wan2.1GP v2.1: More Loras !: added support for 'Safetensors' and 'Replicate' Lora formats.\
|
||||||
|
You will need to refresh the requirements with a *pip install -r requirements.txt*
|
||||||
|
* Mar 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues:
|
||||||
- Clearer user interface
|
- Clearer user interface
|
||||||
- Download 30 Loras in one click to try them all (expand the info section)
|
- Download 30 Loras in one click to try them all (expand the info section)
|
||||||
- Very easy to use Loras as now Lora presets can include the subject (or other needed terms) of the Lora so that you don't have to manually modify a prompt
|
- Very easy to use Loras as now Lora presets can include the subject (or other needed terms) of the Lora so that you don't have to manually modify a prompt
|
||||||
|
|||||||
@ -23,7 +23,12 @@ import asyncio
|
|||||||
from wan.utils import prompt_parser
|
from wan.utils import prompt_parser
|
||||||
PROMPT_VARS_MAX = 10
|
PROMPT_VARS_MAX = 10
|
||||||
|
|
||||||
|
target_mmgp_version = "3.2.8"
|
||||||
|
from importlib.metadata import version
|
||||||
|
mmgp_version = version("mmgp")
|
||||||
|
if mmgp_version != target_mmgp_version:
|
||||||
|
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
|
||||||
|
exit()
|
||||||
def _parse_args():
|
def _parse_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate a video from a text prompt or image using Gradio")
|
description="Generate a video from a text prompt or image using Gradio")
|
||||||
@ -77,8 +82,7 @@ def _parse_args():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--check-loras",
|
"--check-loras",
|
||||||
type=str,
|
action="store_true",
|
||||||
default=0,
|
|
||||||
help="Filter Loras that are not valid"
|
help="Filter Loras that are not valid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -346,6 +350,54 @@ if args.compile: #args.fastest or
|
|||||||
#attention_mode="xformers"
|
#attention_mode="xformers"
|
||||||
# compile = "transformer"
|
# compile = "transformer"
|
||||||
|
|
||||||
|
def preprocess_loras(sd):
|
||||||
|
if not use_image2video:
|
||||||
|
return sd
|
||||||
|
|
||||||
|
new_sd = {}
|
||||||
|
first = next(iter(sd), None)
|
||||||
|
if first == None:
|
||||||
|
return sd
|
||||||
|
if not first.startswith("lora_unet_"):
|
||||||
|
return sd
|
||||||
|
print("Converting Lora Safetensors format to Lora Diffusers format")
|
||||||
|
alphas = {}
|
||||||
|
repl_list = ["cross_attn", "self_attn", "ffn"]
|
||||||
|
src_list = ["_" + k + "_" for k in repl_list]
|
||||||
|
tgt_list = ["." + k + "." for k in repl_list]
|
||||||
|
|
||||||
|
for k,v in sd.items():
|
||||||
|
k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
|
||||||
|
|
||||||
|
for s,t in zip(src_list, tgt_list):
|
||||||
|
k = k.replace(s,t)
|
||||||
|
|
||||||
|
k = k.replace("lora_up","lora_B")
|
||||||
|
k = k.replace("lora_down","lora_A")
|
||||||
|
|
||||||
|
if "alpha" in k:
|
||||||
|
alphas[k] = v
|
||||||
|
else:
|
||||||
|
new_sd[k] = v
|
||||||
|
|
||||||
|
new_alphas = {}
|
||||||
|
for k,v in new_sd.items():
|
||||||
|
if "lora_B" in k:
|
||||||
|
dim = v.shape[1]
|
||||||
|
elif "lora_A" in k:
|
||||||
|
dim = v.shape[0]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
alpha_key = k[:-len("lora_X.weight")] +"alpha"
|
||||||
|
if alpha_key in alphas:
|
||||||
|
scale = alphas[alpha_key] / dim
|
||||||
|
new_alphas[alpha_key] = scale
|
||||||
|
else:
|
||||||
|
print(f"Lora alpha'{alpha_key}' is missing")
|
||||||
|
new_sd.update(new_alphas)
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
def download_models(transformer_filename, text_encoder_filename):
|
def download_models(transformer_filename, text_encoder_filename):
|
||||||
def computeList(filename):
|
def computeList(filename):
|
||||||
pos = filename.rfind("/")
|
pos = filename.rfind("/")
|
||||||
@ -450,7 +502,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
|
|||||||
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
||||||
|
|
||||||
if check_loras:
|
if check_loras:
|
||||||
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=preprocess_loras, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
||||||
|
|
||||||
if len(loras) > 0:
|
if len(loras) > 0:
|
||||||
loras_names = [ Path(lora).stem for lora in loras ]
|
loras_names = [ Path(lora).stem for lora in loras ]
|
||||||
@ -986,7 +1038,7 @@ def generate_video(
|
|||||||
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
||||||
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
|
||||||
pinnedLora = False # profile !=5
|
pinnedLora = False # profile !=5
|
||||||
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
|
||||||
errors = trans._loras_errors
|
errors = trans._loras_errors
|
||||||
if len(errors) > 0:
|
if len(errors) > 0:
|
||||||
error_files = [msg for _ , msg in errors]
|
error_files = [msg for _ , msg in errors]
|
||||||
@ -1534,9 +1586,9 @@ def create_demo():
|
|||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
|
||||||
if use_image2video:
|
if use_image2video:
|
||||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||||
else:
|
else:
|
||||||
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.1 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
||||||
|
|
||||||
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
||||||
|
|
||||||
@ -1727,7 +1779,9 @@ def create_demo():
|
|||||||
|
|
||||||
advanced_prompt = advanced
|
advanced_prompt = advanced
|
||||||
prompt_vars=[]
|
prompt_vars=[]
|
||||||
if not advanced_prompt:
|
if advanced_prompt:
|
||||||
|
default_wizard_prompt, variables, values= None, None, None
|
||||||
|
else:
|
||||||
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
|
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
|
||||||
advanced_prompt = len(errors) > 0
|
advanced_prompt = len(errors) > 0
|
||||||
|
|
||||||
|
|||||||
@ -16,5 +16,5 @@ gradio>=5.0.0
|
|||||||
numpy>=1.23.5,<2
|
numpy>=1.23.5,<2
|
||||||
einops
|
einops
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
mmgp==3.2.7
|
mmgp==3.2.8
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
Loading…
Reference in New Issue
Block a user