diff --git a/README.md b/README.md
index c69848b..e37c10d 100644
--- a/README.md
+++ b/README.md
@@ -19,11 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
+* Mar 19 2025: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
+You will need one more *pip install -r requirements.txt*
 * Mar 18 2025: 👋 Wan2.1GP v3.0:
     - New Tab based interface, you can switch from i2v to t2v and vice versa without restarting the app
     - Experimental Dual Frames mode for i2v, you can also specify an End frame. It doesn't always work, so you will need a few attempts.
     - You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
-    - Slight acceleration with loras
+    - Slight acceleration with loras\
 You will need one more *pip install -r requirements.txt*
 Many thanks to *Tophness* who created the framework (and did a big part of the work) of the multitabs and saved settings features
 * Mar 18 2025: 👋 Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
@@ -255,8 +257,9 @@ You can define multiple lines of macros. If there is only one macro line, the ap
 --check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
 --advanced : turn on the advanced mode while launching the app\
 --i2v-settings : path to launch settings for i2v\
---t2v-settings : path to launch settings for t2v
---listen : make server accessible on network
+--t2v-settings : path to launch settings for t2v\
+--listen : make server accessible on network\
+--gpu device : run Wan on the specified device, for instance "cuda:1"
 
 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
diff --git a/gradio_server.py b/gradio_server.py
index bcc8dda..27d88c5 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -15,7 +15,7 @@ import json
 import wan
 from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
 from wan.utils.utils import cache_video
-from wan.modules.attention import get_attention_modes
+from wan.modules.attention import get_attention_modes, get_supported_attention_modes
 import torch
 import gc
 import traceback
@@ -24,7 +24,7 @@ import asyncio
 from wan.utils import prompt_parser
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.3.1"
+target_mmgp_version = "3.3.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
@@ -55,16 +55,16 @@ def runner():
     while True:
         with lock:
             for item in queue:
-                task_id = item['id']
+                task_id_runner = item['id']
                 with tracker_lock:
-                    progress = progress_tracker.get(task_id, {})
+                    progress = progress_tracker.get(task_id_runner, {})
                 if item['state'] == "Processing":
                     current_step = progress.get('current_step', 0)
                     total_steps = progress.get('total_steps', 0)
                     elapsed = time.time() - progress.get('start_time', time.time())
                     status = progress.get('status', "")
-                    repeats = progress.get("repeats")
+                    repeats = progress.get("repeats", "0/0")
                     item.update({
                         'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
                         'steps': f"{current_step}/{total_steps}",
@@ -381,6 +381,13 @@ def _parse_args():
         help="Server name"
     )
 
+    parser.add_argument(
+        "--gpu",
+        type=str,
+        default="",
+        help="Default GPU Device"
+    )
+
     parser.add_argument(
         "--open-browser",
         action="store_true",
@@ -473,7 +480,8 @@ def get_lora_dir(i2v):
             return lora_dir_14B
     return root_lora_dir
 
-attention_modes_supported = get_attention_modes()
+attention_modes_installed = get_attention_modes()
+attention_modes_supported = get_supported_attention_modes()
 
 args = _parse_args()
 args.flow_reverse = True
@@ -587,6 +595,7 @@ vae_config = server_config.get("vae_config", 0)
 if len(args.vae_config) > 0:
     vae_config = int(args.vae_config)
 
+reload_needed = False
 default_ui = server_config.get("default_ui", "t2v")
 metadata = server_config.get("metadata_type", "metadata")
 save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
@@ -686,7 +695,7 @@ def download_models(transformer_filename, text_encoder_filename):
     from huggingface_hub import hf_hub_download, snapshot_download
     repoId = "DeepBeepMeep/Wan2.1"
     sourceFolderList = ["xlm-roberta-large", "", ]
-    fileList = [ [], ["Wan2.1_VAE.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
+    fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
     targetRoot = "ckpts/"
     for sourceFolder, files in zip(sourceFolderList,fileList ):
         if len(files)==0:
@@ -703,6 +712,14 @@ def download_models(transformer_filename, text_encoder_filename):
 
 offload.default_verboseLevel = verbose_level
 
+to_remove = ["models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", "Wan2.1_VAE.pth"]
+for file_name in to_remove:
+    file_name = os.path.join("ckpts",file_name)
+    if os.path.isfile(file_name):
+        try:
+            os.remove(file_name)
+        except:
+            pass
 
 download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename)
@@ -875,6 +892,8 @@ def load_models(i2v):
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
     offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+    if len(args.gpu) > 0:
+        torch.set_default_device(args.gpu)
 
     return wan_model, offloadobj, pipe["transformer"]
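The new `--gpu` argument flows from `_parse_args()` into `load_models()` above, where it is handed to `torch.set_default_device` once the offload profile is built. A minimal sketch of the PyTorch behavior this relies on; `cuda:1` is a hypothetical second GPU, not something the repo requires:

```python
import torch

# torch.set_default_device (PyTorch 2.0+) changes where new tensors are
# allocated when no explicit device is given.
torch.set_default_device("cuda:1")  # hypothetical second GPU

x = torch.randn(2, 3)                # lands on cuda:1
y = torch.zeros(2, 3, device="cpu")  # an explicit device still wins
print(x.device, y.device)            # cuda:1 cpu
```

With this patch applied, the server can be started as `python gradio_server.py --gpu "cuda:1"` to keep generation off the primary GPU.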
@@ -914,8 +933,10 @@ def generate_header(model_filename, compile, attention_mode):
     header += model_name
     header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
-    if attention_mode not in attention_modes_supported:
+    if attention_mode not in attention_modes_installed:
         header += " -NOT INSTALLED-"
+    elif attention_mode not in attention_modes_supported:
+        header += " -NOT SUPPORTED-"
 
     if compile:
         header += ", pytorch compilation ON"
@@ -979,11 +1000,7 @@ def apply_changes( state,
         if v != v_old:
             changes.append(k)
 
-    state["config_changes"] = changes
-    state["config_new"] = server_config
-    state["config_old"] = old_server_config
-
-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir, reload_needed
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -995,7 +1012,7 @@
     if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
         pass
     else:
-        state["_reload_needed"] = True
+        reload_needed = True
 
     yield "The new configuration has been successfully applied"
@@ -1013,7 +1030,7 @@ def save_video(final_frames, output_path, fps=24):
 
 def build_callback(taskid, state, pipe, num_inference_steps, repeats):
     start_time = time.time()
-    def update_progress(step_idx, latents, read_state = False):
+    def update_progress(step_idx, _):
         with tracker_lock:
             step_idx += 1
             if state.get("abort", False):
@@ -1094,8 +1111,7 @@ def generate_video(
     progress=gr.Progress() #track_tqdm= True
 ):
-    global wan_model, offloadobj, last_model_type
-    reload_needed = state.get("_reload_needed", False)
+    global wan_model, offloadobj, reload_needed, last_model_type
     file_model_needed = model_needed(image2video)
     with lock:
         queue_not_empty = len(queue) > 0
@@ -1108,7 +1124,7 @@
         print(f"Loading model {get_model_name(file_model_needed)}...")
         wan_model, offloadobj, trans = load_models(image2video)
         print(f"Model loaded")
-        state["_reload_needed"] = False
+        reload_needed = False
 
     from PIL import Image
     import numpy as np
@@ -1121,11 +1137,12 @@
     elif attention_mode in attention_modes_supported:
         attn = attention_mode
     else:
-        gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
+        gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
         return
 
     #if state.get("validate_success",0) != 1:
     #    return
+    raw_resolution = resolution
     width, height = resolution.split("x")
     width, height = int(width), int(height)
@@ -1289,7 +1306,7 @@
                 'total_steps': num_inference_steps,
                 'start_time': time.time(),
                 'last_update': time.time(),
-                'repeats': f"0/{repeat_generation}",
+                'repeats': f"{video_no}/{repeat_generation}",
                 'status': "Encoding Prompt"
             }
         video_no += 1
@@ -1401,14 +1418,8 @@
                     normalize=True,
                     value_range=(-1, 1))
 
-                configs = {
-                    'prompt': prompt,
-                    'negative_prompt': negative_prompt,
-                    'resolution': resolution,
-                    'video_length': video_length,
-                    'seed': seed,
-                    'num_inference_steps': num_inference_steps,
-                }
+                configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+                    loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
 
                 metadata_choice = server_config.get("metadata_choice","metadata")
                 if metadata_choice == "json":
@@ -1715,19 +1726,14 @@ def switch_advanced(state, new_advanced, lset_name):
     else:
         return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
 
-def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
                  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
-    if state.get("validate_success",0) != 1:
-        return
-
-    loras_choices
     loras = state["loras"]
     activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
 
-    ui_defaults = {
+    ui_settings = {
         "prompts": prompt,
-        "image_prompt_type": image_prompt_type,
         "resolution": resolution,
         "video_length": video_length,
         "num_inference_steps": num_inference_steps,
@@ -1747,10 +1753,25 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
         "slg_start_perc": slg_start_perc,
         "slg_end_perc": slg_end_perc
     }
+
+    if i2v:
+        ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
+        ui_settings["image_prompt_type"] = image_prompt_type
+    else:
+        ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
+
+    return ui_settings
+
+def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+                 loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
+    if state.get("validate_success",0) != 1:
+        return
+
+    ui_settings = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
+        loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
+
     defaults_filename = get_settings_file_name(use_image2video)
 
     with open(defaults_filename, "w", encoding="utf-8") as f:
-        json.dump(ui_defaults, f, indent=4)
+        json.dump(ui_settings , f, indent=4)
 
     gr.Info("New Default Settings saved")
@@ -1864,7 +1885,8 @@ def generate_video_tab(image2video=False):
                     cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
             video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
-            image_prompt_type = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =ui_defaults["image_prompt_type"], label="Location", show_label= False, scale= 3, visible=image2video)
+            image_prompt_type= ui_defaults.get("image_prompt_type",0)
+            image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
 
             if args.multiple_images:
                 image_to_continue = gr.Gallery(
@@ -1876,9 +1898,9 @@
             if args.multiple_images:
                 image_to_end = gr.Gallery(
                     label="Images as ending points for new videos", type ="pil", #file_types= "image",
-                    columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=False)
+                    columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
             else:
-                image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible= False)
+                image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
 
             def switch_image_prompt_type_radio(image_prompt_type_radio):
                 if args.multiple_images:
                     return gr.Gallery(visible = (image_prompt_type_radio == 1) )
                 else:
                     return gr.Image(visible = (image_prompt_type_radio == 1) )
 
-            image_prompt_type.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type], outputs=[image_to_end])
+            image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
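Renaming the component to `image_prompt_type_radio` frees `image_prompt_type` to hold the plain value restored from `ui_defaults`, and the end-image inputs now start out visible whenever that saved value is 1. The wiring is the standard Gradio show/hide pattern; a self-contained sketch of that pattern, not this app's actual layout:

```python
import gradio as gr

def toggle_end_image(choice):
    # Returning a component acts as an update to the existing one;
    # here we only flip its visibility.
    return gr.Image(visible=(choice == 1))

with gr.Blocks() as demo:
    radio = gr.Radio(
        [("Use only a Start Image", 0), ("Use both a Start and an End Image", 1)],
        value=0,
        show_label=False,
    )
    end_image = gr.Image(label="Last Image for a new video", visible=False)
    radio.change(fn=toggle_end_image, inputs=[radio], outputs=[end_image])

demo.launch()
```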
             advanced_prompt = advanced
@@ -2080,7 +2102,7 @@ def generate_video_tab(image2video=False):
                 outputs=[output]
             )
             save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
-                save_settings, inputs = [state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
+                save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
                     loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc ], outputs = [])
             save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2182,8 +2204,10 @@ def generate_configuration_tab():
                 value=server_config.get("save_path", save_path)
             )
             def check(mode):
-                if not mode in attention_modes_supported:
+                if not mode in attention_modes_installed:
                     return " (NOT INSTALLED)"
+                elif not mode in attention_modes_supported:
+                    return " (NOT SUPPORTED)"
                 else:
                     return ""
             attention_choice = gr.Dropdown(
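Both the header and this configuration dropdown now distinguish a backend that is not installed (its import failed) from one that is installed but unusable on the current hardware. A toy illustration of the two-tier check; the two lists are made-up values, not what the app computes:

```python
# Sketch only: mirrors the generate_header()/check() logic above.
attention_modes_installed = ["sdpa", "sage", "sage2"]  # imports succeeded
attention_modes_supported = ["sdpa"]                   # usable on this GPU

def label(mode: str) -> str:
    if mode not in attention_modes_installed:
        return f"{mode} -NOT INSTALLED-"
    if mode not in attention_modes_supported:
        return f"{mode} -NOT SUPPORTED-"
    return mode

print(label("flash"))  # flash -NOT INSTALLED-
print(label("sage2"))  # sage2 -NOT SUPPORTED-
print(label("sdpa"))   # sdpa
```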

@@ -2435,7 +2459,7 @@ def create_demo():
         }
     """
     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
-        gr.Markdown("Wan 2.1GP v3.0 by DeepBeepMeep (Updates)")
+        gr.Markdown("Wan 2.1GP v3.1 by DeepBeepMeep (Updates)")
         gr.Markdown("Welcome to Wan 2.1GP, a super fast and low VRAM AI Video Generator!")
 
         with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
diff --git a/requirements.txt b/requirements.txt
index b97e545..bd928de 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,6 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.3.0
+mmgp==3.3.3
 peft==0.14.0
 mutagen
\ No newline at end of file
diff --git a/wan/image2video.py b/wan/image2video.py
index 506fe90..a853665 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -177,7 +177,7 @@ class WanI2V:
 
         logging.info(f"Creating WanModel from {model_filename}")
         from mmgp import offload
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index e40ac63..6861283 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -30,6 +30,7 @@ try:
         max_seqlen_kv,
     ):
         return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+
 except ImportError:
     sageattn_varlen_wrapper = None
 
@@ -38,11 +39,12 @@ import warnings
 
 try:
     from sageattention import sageattn
-    from .sage2_core import sageattn as alt_sageattn
+    from .sage2_core import sageattn as alt_sageattn, is_sage_supported
+    sage_supported = is_sage_supported()
 except ImportError:
     sageattn = None
     alt_sageattn = None
-
+    sage_supported = False
 # @torch.compiler.disable()
 def sageattn_wrapper(
     qkv_list,
@@ -129,6 +131,14 @@ def get_attention_modes():
     return ret
 
+def get_supported_attention_modes():
+    ret = get_attention_modes()
+    if not sage_supported:
+        if "sage" in ret:
+            ret.remove("sage")
+        if "sage2" in ret:
+            ret.remove("sage2")
+    return ret
 
 __all__ = [
     'pay_attention',
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index 53b76c1..fc41d85 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -519,8 +519,11 @@ class CLIPModel:
             device=device)
         self.model = self.model.eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
-        self.model.load_state_dict(
-            torch.load(checkpoint_path, map_location='cpu'), assign= True)
+        from mmgp import offload
+        # self.model.load_state_dict(
+        #     torch.load(checkpoint_path, map_location='cpu'), assign= True)
+
+        offload.load_model_data(self.model, checkpoint_path.replace(".pth", "-bf16.safetensors"), writable_tensors= False)
 
         # init tokenizer
         self.tokenizer = HuggingfaceTokenizer(
diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py
index d83d255..de94a60 100644
--- a/wan/modules/sage2_core.py
+++ b/wan/modules/sage2_core.py
@@ -51,6 +51,15 @@ from sageattention.quant import per_channel_fp8
 from typing import Any, List, Literal, Optional, Tuple, Union
 import warnings
+import os
+
+def is_sage_supported():
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            return False
+    return True
 
 def get_cuda_arch_versions():
     cuda_archs = []
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
index 5cccbbf..110e358 100644
--- a/wan/modules/t5.py
+++ b/wan/modules/t5.py
@@ -496,7 +496,7 @@ class T5EncoderModel:
                 device=device).eval().requires_grad_(False)
         logging.info(f'loading {checkpoint_path}')
         from mmgp import offload
-        offload.load_model_data(model,checkpoint_path )
+        offload.load_model_data(model,checkpoint_path, writable_tensors= False )
         self.model = model
         if shard_fn is not None:
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
index e47e74c..67dcd9a 100644
--- a/wan/modules/vae.py
+++ b/wan/modules/vae.py
@@ -744,11 +744,12 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
     with torch.device('meta'):
         model = WanVAE_(**cfg)
 
+    from mmgp import offload
     # load checkpoint
     logging.info(f'loading {pretrained_path}')
-    model.load_state_dict(
-        torch.load(pretrained_path, map_location=device), assign=True)
-
+    # model.load_state_dict(
+    #     torch.load(pretrained_path, map_location=device), assign=True)
+    offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
     return model
 
@@ -778,7 +779,7 @@ class WanVAE:
         self.model = _video_vae(
             pretrained_path=vae_pth,
             z_dim=z_dim,
-        ).eval().requires_grad_(False).to(device)
+        ).eval() #.requires_grad_(False).to(device)
 
     def encode(self, videos, tile_size = 256, any_end_frame = False):
         """
diff --git a/wan/text2video.py b/wan/text2video.py
index 4682a4a..88046db 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -89,7 +89,7 @@ class WanT2V:
 
         from mmgp import offload
 
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
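Across CLIP, T5, the VAE, and both DiT loaders, weights now come from bf16 `.safetensors` files loaded with `writable_tensors=False`, presumably letting mmgp keep tensors backed by the memory-mapped file rather than copying them into writable RAM; this is the v3.1 RAM saving the README entry refers to. A rough sketch of the underlying idea using the `safetensors` library directly (the file name is assumed from the `.pth` to `*_bf16.safetensors` renaming above; mmgp's internals may differ):

```python
import torch
from safetensors.torch import load_file

# File name assumed from the diff's .pth -> *_bf16.safetensors scheme.
state_dict = load_file("ckpts/Wan2.1_VAE_bf16.safetensors")

# safetensors reads via memory-mapping and keeps the stored dtype, so
# bf16 weights take half the RAM of an fp32 .pth pickle.
total = sum(t.numel() * t.element_size() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {total / 1e9:.2f} GB, "
      f"dtypes: {set(t.dtype for t in state_dict.values())}")
```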