added lora unet support for flux

This commit is contained in:
deepbeepmeep 2025-07-21 13:54:18 +02:00
parent 075aaa8f90
commit b5676254f8

View File

@ -84,14 +84,34 @@ class Flux(nn.Module):
def preprocess_loras(self, model_type, sd):
new_sd = {}
if len(sd) == 0: return sd
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
first_key= next(iter(sd))
if first_key.startswith("transformer."):
if first_key.startswith("lora_unet_"):
new_sd = {}
print("Converting Lora Safetensors format to Lora Diffusers format")
repl_list = ["linear1", "linear2", "modulation_lin"]
src_list = ["_" + k + "." for k in repl_list]
tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
for k,v in sd.items():
k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
k = k.replace("lora_unet__blocks_","diffusion_model.blocks.")
k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.")
for s,t in zip(src_list, tgt_list):
k = k.replace(s,t)
k = k.replace("lora_up","lora_B")
k = k.replace("lora_down","lora_A")
new_sd[k] = v
elif first_key.startswith("transformer."):
root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1", "time_text_embed.text_embedder.linear_2",
"time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2",
"x_embedder", "context_embedder", "proj_out" ]