mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2026-02-06 18:57:49 +00:00
added lora unet support for flux
This commit is contained in:
parent
075aaa8f90
commit
b5676254f8
@ -91,7 +91,27 @@ class Flux(nn.Module):
|
||||
return new_weight
|
||||
|
||||
first_key= next(iter(sd))
|
||||
if first_key.startswith("transformer."):
|
||||
if first_key.startswith("lora_unet_"):
|
||||
new_sd = {}
|
||||
print("Converting Lora Safetensors format to Lora Diffusers format")
|
||||
repl_list = ["linear1", "linear2", "modulation_lin"]
|
||||
src_list = ["_" + k + "." for k in repl_list]
|
||||
tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
|
||||
|
||||
for k,v in sd.items():
|
||||
k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
|
||||
k = k.replace("lora_unet__blocks_","diffusion_model.blocks.")
|
||||
k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.")
|
||||
|
||||
for s,t in zip(src_list, tgt_list):
|
||||
k = k.replace(s,t)
|
||||
|
||||
k = k.replace("lora_up","lora_B")
|
||||
k = k.replace("lora_down","lora_A")
|
||||
|
||||
new_sd[k] = v
|
||||
|
||||
elif first_key.startswith("transformer."):
|
||||
root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1", "time_text_embed.text_embedder.linear_2",
|
||||
"time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2",
|
||||
"x_embedder", "context_embedder", "proj_out" ]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user