From 7b3715f410996eb2ba9808bcfce81f9418f73094 Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Mon, 25 Aug 2025 15:26:32 +0200 Subject: [PATCH] added support for qwen lora safetensors format --- models/qwen/transformer_qwenimage.py | 38 +++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py index 8751648..6d90806 100644 --- a/models/qwen/transformer_qwenimage.py +++ b/models/qwen/transformer_qwenimage.py @@ -497,16 +497,36 @@ class QwenImageTransformer2DModel(nn.Module): new_sd[k] = v sd = new_sd - if first.startswith("transformer_blocks"): - new_sd = {} - for k,v in sd.items(): - if k.startswith("transformer_blocks"): - k = "diffusion_model." + k + prefix_list = ["lora_unet_transformer_blocks"] + for prefix in prefix_list: + if first.startswith(prefix): + repl_list = ["attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"] + src_list = ["_" + k + "_" for k in repl_list] + tgt_list = ["." + k + "." for k in repl_list] + src_list2 = ["_0_", "_0.", "_1.", "_2."] + tgt_list2 = [".0.", ".0.", ".1.", ".2."] + new_sd = {} + for k,v in sd.items(): + k = "diffusion_model.transformer_blocks." + k[len(prefix)+1:] + for s,t in zip(src_list, tgt_list): + k = k.replace(s,t) + for s,t in zip(src_list2, tgt_list2): + k = k.replace(s,t) new_sd[k] = v - sd = new_sd - return sd - else: - return sd + sd = new_sd + return sd + + prefix_list = ["transformer_blocks"] + for prefix in prefix_list: + if first.startswith(prefix): + new_sd = {} + for k,v in sd.items(): + if k.startswith(prefix): + k = "diffusion_model." + k + new_sd[k] = v + sd = new_sd + return sd + return sd def __init__( self,