mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-15 11:43:21 +00:00
added support for qwen lora safetensors format
This commit is contained in:
parent
9b6448c19c
commit
7b3715f410
@ -497,16 +497,36 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
new_sd[k] = v
|
||||
sd = new_sd
|
||||
|
||||
if first.startswith("transformer_blocks"):
|
||||
new_sd = {}
|
||||
for k,v in sd.items():
|
||||
if k.startswith("transformer_blocks"):
|
||||
k = "diffusion_model." + k
|
||||
prefix_list = ["lora_unet_transformer_blocks"]
|
||||
for prefix in prefix_list:
|
||||
if first.startswith(prefix):
|
||||
repl_list = ["attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"]
|
||||
src_list = ["_" + k + "_" for k in repl_list]
|
||||
tgt_list = ["." + k + "." for k in repl_list]
|
||||
src_list2 = ["_0_", "_0.", "_1.", "_2."]
|
||||
tgt_list2 = [".0.", ".0.", ".1.", ".2."]
|
||||
new_sd = {}
|
||||
for k,v in sd.items():
|
||||
k = "diffusion_model.transformer_blocks." + k[len(prefix)+1:]
|
||||
for s,t in zip(src_list, tgt_list):
|
||||
k = k.replace(s,t)
|
||||
for s,t in zip(src_list2, tgt_list2):
|
||||
k = k.replace(s,t)
|
||||
new_sd[k] = v
|
||||
sd = new_sd
|
||||
return sd
|
||||
else:
|
||||
return sd
|
||||
sd = new_sd
|
||||
return sd
|
||||
|
||||
prefix_list = ["transformer_blocks"]
|
||||
for prefix in prefix_list:
|
||||
if first.startswith(prefix):
|
||||
new_sd = {}
|
||||
for k,v in sd.items():
|
||||
if k.startswith(prefix):
|
||||
k = "diffusion_model." + k
|
||||
new_sd[k] = v
|
||||
sd = new_sd
|
||||
return sd
|
||||
return sd
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user