From 1b125fe2331e884297aa724ffc971e82e137afc2 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Mon, 24 Mar 2025 23:18:02 +1100
Subject: [PATCH] merge latest vram fixes

---
 gradio_server.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/gradio_server.py b/gradio_server.py
index 2c71042..9671334 100644
--- a/gradio_server.py
+++ b/gradio_server.py
@@ -854,6 +854,8 @@ def load_models(i2v):
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:
         kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
+        if profile == 4:
+            kwargs["partialPinning"] = True
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
     offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
@@ -864,7 +866,7 @@ wan_model, offloadobj, transformer = load_models(use_image2video)
 if check_loras:
     setup_loras(use_image2video, transformer, get_lora_dir(use_image2video), "", None)
     exit()
-
+del transformer
 gen_in_progress = False
 
 def get_auto_attention():
@@ -1080,10 +1082,10 @@ def generate_video(
     file_model_needed = model_needed(image2video)
     if(server_config.get("reload_model",1) == 2):
         if file_model_needed != model_filename or reload_needed:
+            del wan_model
             if offloadobj is not None:
                 offloadobj.release()
-                offloadobj = None
-            wan_model = None
+            del offloadobj
             gc.collect()
             print(f"Loading model {get_model_name(file_model_needed)}...")
             wan_model, offloadobj, trans = load_models(image2video)
@@ -1194,7 +1196,7 @@ def generate_video(
     if len(list_mult_choices_nums ) < len(loras_choices):
         list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
     loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
-    pinnedLora = False !=5 #False # # #
+    pinnedLora = False #profile !=5 #False # # #
     offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
     errors = trans._loras_errors
     if len(errors) > 0: