diff --git a/fantasytalking/infer.py b/fantasytalking/infer.py
index f2d4964..e7bdc6f 100644
--- a/fantasytalking/infer.py
+++ b/fantasytalking/infer.py
@@ -4,24 +4,33 @@ from transformers import Wav2Vec2Model, Wav2Vec2Processor
 from .model import FantasyTalkingAudioConditionModel
 from .utils import get_audio_features
-
+import gc, torch
 
 def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
     fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
     from mmgp import offload
     from accelerate import init_empty_weights
     from fantasytalking.model import AudioProjModel
+
+    torch.set_grad_enabled(False)
+
     with init_empty_weights():
         proj_model = AudioProjModel( 768, 2048)
     offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
-    proj_model.to(device).eval().requires_grad_(False)
+    proj_model.to("cpu").eval().requires_grad_(False)
 
     wav2vec_model_dir = "ckpts/wav2vec"
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
-    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
+    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
+    wav2vec.to(device)
+    proj_model.to(device)
     audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
 
     audio_proj_fea = proj_model(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
-    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    wav2vec, proj_model= None, None
+    gc.collect()
+    torch.cuda.empty_cache()
+
     return audio_proj_split, audio_context_lens
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e31f7dc..ec8ea94 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.4.2
+mmgp==3.4.3
 peft==0.14.0
 mutagen
 pydantic==2.10.6
diff --git a/wan/image2video.py b/wan/image2video.py
index 7b0fcb7..db36ec8 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -103,7 +103,7 @@ class WanI2V:
         # dtype = torch.float16
         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
         self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
-        # offload.change_dtype(self.model, dtype, True)
+        offload.change_dtype(self.model, dtype, True)
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
@@ -403,9 +403,7 @@ class WanI2V:
             if callback is not None:
                 callback(i, latent, False)
 
-        x0 = [latent]
-
-        # x0 = [lat_y]
+        x0 = [latent]
 
         video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
         if any_end_frame and add_frames_for_end_image:
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 50f1d8c..8892454 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -312,8 +312,6 @@ class WanI2VCrossAttention(WanSelfAttention):
         del x
         self.norm_q(q)
         q= q.view(b, -1, n, d)
-        if audio_scale != None:
-            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k = self.k(context)
         self.norm_k(k)
         k = k.view(b, -1, n, d)
@@ -323,6 +321,8 @@ class WanI2VCrossAttention(WanSelfAttention):
         del k,v
         x = pay_attention(qkv_list)
 
+        if audio_scale != None:
+            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k_img = self.k_img(context_img)
         self.norm_k_img(k_img)
         k_img = k_img.view(b, -1, n, d)
diff --git a/wgp.py b/wgp.py
index 935966e..5357adb 100644
--- a/wgp.py
+++ b/wgp.py
@@ -40,7 +40,7 @@ global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.4.2"
+target_mmgp_version = "3.4.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
@@ -50,6 +50,7 @@ lock = threading.Lock()
 current_task_id = None
 task_id = 0
 
+
 def download_ffmpeg():
     if os.name != 'nt': return
     exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
@@ -1421,6 +1422,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
              "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors" ]:
     if Path(os.path.join("ckpts" , path)).is_file():
+        print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
         os.remove( os.path.join("ckpts" , path))
 
@@ -1511,14 +1513,21 @@ def get_model_filename(model_type, quantization):
         quantization = "bf16"
 
     if len(choices) <= 1:
-        return choices[0]
-
-    sub_choices = [ name for name in choices if quantization in name]
-    if len(sub_choices) > 0:
-        return sub_choices[0]
+        raw_filename = choices[0]
     else:
-        return choices[0]
+        sub_choices = [ name for name in choices if quantization in name]
+        if len(sub_choices) > 0:
+            raw_filename = sub_choices[0]
+        else:
+            raw_filename = choices[0]
+    if transformer_dtype == torch.float16 :
+        if "quanto_int8" in raw_filename:
+            raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
+        elif "quanto_mbf16_int8" in raw_filename:
+            raw_filename= raw_filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
+    return raw_filename
+
 def get_settings_file_name(model_filename):
     return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")
@@ -1599,6 +1608,13 @@ def get_default_settings(filename):
         ui_defaults["num_inference_steps"] = default_number_steps
     return ui_defaults
 
+major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if major < 8:
+    print("Switching to f16 models as GPU architecture doesn't support bf16")
+    transformer_dtype = torch.float16
+else:
+    transformer_dtype = torch.float16 if args.fp16 else torch.bfloat16
+
 transformer_types = server_config.get("transformer_types", [])
 transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
 transformer_quantization =server_config.get("transformer_quantization", "int8")
@@ -1892,32 +1908,17 @@ def load_models(model_filename):
     global transformer_filename
     perc_reserved_mem_max = args.perc_reserved_mem_max
-
-    major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-    if major < 8:
-        print("Switching to f16 model as GPU architecture doesn't support bf16")
-        default_dtype = torch.float16
-    else:
-        default_dtype = torch.float16 if args.fp16 else torch.bfloat16
     model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
-    updated_model_filename = []
     for filename in model_filelist:
-        if default_dtype == torch.float16 :
-            if "quanto_int8" in filename:
-                filename = filename.replace("quanto_int8", "quanto_fp16_int8")
-            elif "quanto_mbf16_int8":
-                filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
-        updated_model_filename.append(filename)
         download_models(filename, text_encoder_filename)
-    model_filelist = updated_model_filename
 
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
     new_transformer_filename = model_filelist[-1]
     if test_class_i2v(new_transformer_filename):
-        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
-        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     wan_model._model_file_name = new_transformer_filename
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:
@@ -1926,7 +1927,7 @@ def load_models(model_filename):
         # kwargs["partialPinning"] = True
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
@@ -2410,6 +2411,7 @@ def generate_video(
 ):
     global wan_model, offloadobj, reload_needed
     gen = get_gen_info(state)
+    torch.set_grad_enabled(False)
 
     file_list = gen["file_list"]
     prompt_no = gen["prompt_no"]
@@ -2574,6 +2576,7 @@ def generate_video(
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)
 
+    torch.set_grad_enabled(False)
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
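For reference, a minimal standalone sketch of the fp16 fallback that get_model_filename now applies: when the transformer dtype drops to fp16 on a GPU without bf16 support, the bf16-flavoured quantized checkpoint names are remapped to their fp16 counterparts. The helper name remap_for_fp16 is illustrative only and not part of the patch; the example filename is taken from the comments in wan/image2video.py.

# Illustrative sketch only -- mirrors the substitution added in get_model_filename above.
import torch

def remap_for_fp16(filename: str, dtype: torch.dtype) -> str:
    # Swap bf16-calibrated quantized checkpoint names for their fp16 variants
    # when the transformer runs in float16.
    if dtype == torch.float16:
        if "quanto_int8" in filename:
            filename = filename.replace("quanto_int8", "quanto_fp16_int8")
        elif "quanto_mbf16_int8" in filename:
            filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
    return filename

# Example:
# remap_for_fp16("wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", torch.float16)
# -> "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors"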