merge latest vram fixes

This commit is contained in:
Chris Malone 2025-03-24 23:18:02 +11:00
parent 347ab55d4a
commit 1b125fe233

View File

@ -854,6 +854,8 @@ def load_models(i2v):
kwargs = { "extraModelsToQuantize": None} kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4: if profile == 2 or profile == 4:
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 } kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
if profile == 4:
kwargs["partialPinning"] = True
elif profile == 3: elif profile == 3:
kwargs["budgets"] = { "*" : "70%" } kwargs["budgets"] = { "*" : "70%" }
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs) offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
@ -864,7 +866,7 @@ wan_model, offloadobj, transformer = load_models(use_image2video)
if check_loras: if check_loras:
setup_loras(use_image2video, transformer, get_lora_dir(use_image2video), "", None) setup_loras(use_image2video, transformer, get_lora_dir(use_image2video), "", None)
exit() exit()
del transformer
gen_in_progress = False gen_in_progress = False
def get_auto_attention(): def get_auto_attention():
@ -1080,10 +1082,10 @@ def generate_video(
file_model_needed = model_needed(image2video) file_model_needed = model_needed(image2video)
if(server_config.get("reload_model",1) == 2): if(server_config.get("reload_model",1) == 2):
if file_model_needed != model_filename or reload_needed: if file_model_needed != model_filename or reload_needed:
del wan_model
if offloadobj is not None: if offloadobj is not None:
offloadobj.release() offloadobj.release()
offloadobj = None del offloadobj
wan_model = None
gc.collect() gc.collect()
print(f"Loading model {get_model_name(file_model_needed)}...") print(f"Loading model {get_model_name(file_model_needed)}...")
wan_model, offloadobj, trans = load_models(image2video) wan_model, offloadobj, trans = load_models(image2video)
@ -1194,7 +1196,7 @@ def generate_video(
if len(list_mult_choices_nums ) < len(loras_choices): if len(list_mult_choices_nums ) < len(loras_choices):
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) ) list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices] loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
pinnedLora = False !=5 #False # # # pinnedLora = False #profile !=5 #False # # #
offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None) offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
errors = trans._loras_errors errors = trans._loras_errors
if len(errors) > 0: if len(errors) > 0: