added support for qwen lora safetensors format

This commit is contained in:
deepbeepmeep 2025-08-25 15:26:32 +02:00
parent 9b6448c19c
commit 7b3715f410

View File

@ -497,16 +497,36 @@ class QwenImageTransformer2DModel(nn.Module):
new_sd[k] = v
sd = new_sd
if first.startswith("transformer_blocks"):
new_sd = {}
for k,v in sd.items():
if k.startswith("transformer_blocks"):
k = "diffusion_model." + k
prefix_list = ["lora_unet_transformer_blocks"]
for prefix in prefix_list:
if first.startswith(prefix):
repl_list = ["attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"]
src_list = ["_" + k + "_" for k in repl_list]
tgt_list = ["." + k + "." for k in repl_list]
src_list2 = ["_0_", "_0.", "_1.", "_2."]
tgt_list2 = [".0.", ".0.", ".1.", ".2."]
new_sd = {}
for k,v in sd.items():
k = "diffusion_model.transformer_blocks." + k[len(prefix)+1:]
for s,t in zip(src_list, tgt_list):
k = k.replace(s,t)
for s,t in zip(src_list2, tgt_list2):
k = k.replace(s,t)
new_sd[k] = v
sd = new_sd
return sd
else:
return sd
sd = new_sd
return sd
prefix_list = ["transformer_blocks"]
for prefix in prefix_list:
if first.startswith(prefix):
new_sd = {}
for k,v in sd.items():
if k.startswith(prefix):
k = "diffusion_model." + k
new_sd[k] = v
sd = new_sd
return sd
return sd
def __init__(
self,