diff --git a/README.md b/README.md index ea22b03..186c96a 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! -* Marc 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues: +* Mar 18 2022: 👋 Wan2.1GP v2.1: More Loras !: added support for 'Safetensors' and 'Replicate' Lora formats.\ +You will need to refresh the requirements with a *pip install -r requirements.txt* +* Mar 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues: - Clearer user interface - Download 30 Loras in one click to try them all (expand the info section) - Very to use Loras as now Lora presets can input the subject (or other need terms) of the Lora so that you dont have to modify manually a prompt diff --git a/gradio_server.py b/gradio_server.py index 7c45ccf..f7588f9 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -23,7 +23,12 @@ import asyncio from wan.utils import prompt_parser PROMPT_VARS_MAX = 10 - +target_mmgp_version = "3.2.8" +from importlib.metadata import version +mmgp_version = version("mmgp") +if mmgp_version != target_mmgp_version: + print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'") + exit() def _parse_args(): parser = argparse.ArgumentParser( description="Generate a video from a text prompt or image using Gradio") @@ -77,8 +82,7 @@ def _parse_args(): parser.add_argument( "--check-loras", - type=str, - default=0, + action="store_true", help="Filter Loras that are not valid" ) @@ -346,6 +350,54 @@ if args.compile: #args.fastest or #attention_mode="xformers" # compile = "transformer" +def preprocess_loras(sd): + if not use_image2video: + return sd + + new_sd = {} + first = next(iter(sd), None) + if first == None: + return sd + if not first.startswith("lora_unet_"): + return sd + print("Converting Lora Safetensors format to Lora Diffusers format") + alphas = {} + repl_list = ["cross_attn", "self_attn", "ffn"] + src_list = ["_" + k + "_" for k in repl_list] + tgt_list = ["." + k + "." for k in repl_list] + + for k,v in sd.items(): + k = k.replace("lora_unet_blocks_","diffusion_model.blocks.") + + for s,t in zip(src_list, tgt_list): + k = k.replace(s,t) + + k = k.replace("lora_up","lora_B") + k = k.replace("lora_down","lora_A") + + if "alpha" in k: + alphas[k] = v + else: + new_sd[k] = v + + new_alphas = {} + for k,v in new_sd.items(): + if "lora_B" in k: + dim = v.shape[1] + elif "lora_A" in k: + dim = v.shape[0] + else: + continue + alpha_key = k[:-len("lora_X.weight")] +"alpha" + if alpha_key in alphas: + scale = alphas[alpha_key] / dim + new_alphas[alpha_key] = scale + else: + print(f"Lora alpha'{alpha_key}' is missing") + new_sd.update(new_alphas) + return new_sd + + def download_models(transformer_filename, text_encoder_filename): def computeList(filename): pos = filename.rfind("/") @@ -450,7 +502,7 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets] if check_loras: - loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier, + loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=preprocess_loras, split_linear_modules_map = split_linear_modules_map) #lora_multiplier, if len(loras) > 0: loras_names = [ Path(lora).stem for lora in loras ] @@ -986,7 +1038,7 @@ def generate_video( list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] pinnedLora = False # profile !=5 - offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None) + offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) errors = trans._loras_errors if len(errors) > 0: error_files = [msg for _ , msg in errors] @@ -1534,9 +1586,9 @@ def create_demo(): state_dict = {} if use_image2video: - gr.Markdown("