Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-15 11:43:21 +00:00)
added lora unet support for flux
commit b5676254f8
parent 075aaa8f90
@@ -84,14 +84,34 @@ class Flux(nn.Module):
     def preprocess_loras(self, model_type, sd):
         new_sd = {}
         if len(sd) == 0: return sd
 
         def swap_scale_shift(weight):
             shift, scale = weight.chunk(2, dim=0)
             new_weight = torch.cat([scale, shift], dim=0)
             return new_weight
 
         first_key = next(iter(sd))
-        if first_key.startswith("transformer."):
+        if first_key.startswith("lora_unet_"):
+            new_sd = {}
+            print("Converting Lora Safetensors format to Lora Diffusers format")
+            repl_list = ["linear1", "linear2", "modulation_lin"]
+            src_list = ["_" + k + "." for k in repl_list]
+            tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
+
+            for k, v in sd.items():
+                k = k.replace("lora_unet_blocks_", "diffusion_model.blocks.")
+                k = k.replace("lora_unet__blocks_", "diffusion_model.blocks.")
+                k = k.replace("lora_unet_single_blocks_", "diffusion_model.single_blocks.")
+
+                for s, t in zip(src_list, tgt_list):
+                    k = k.replace(s, t)
+
+                k = k.replace("lora_up", "lora_B")
+                k = k.replace("lora_down", "lora_A")
+
+                new_sd[k] = v
+
+        elif first_key.startswith("transformer."):
             root_src = ["time_text_embed.timestep_embedder.linear_1", "time_text_embed.timestep_embedder.linear_2", "time_text_embed.text_embedder.linear_1", "time_text_embed.text_embedder.linear_2",
                 "time_text_embed.guidance_embedder.linear_1", "time_text_embed.guidance_embedder.linear_2",
                 "x_embedder", "context_embedder", "proj_out" ]
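Below is a minimal standalone sketch (not part of the commit) of what the key remapping in the new lora_unet_ branch does to a typical Kohya-style LoRA key. The sample key name and the convert_key helper are illustrative assumptions, not code from the repository; the replacement rules themselves are taken from the diff above.

# Sketch of the lora_unet_ key remapping; sample key is illustrative only.
repl_list = ["linear1", "linear2", "modulation_lin"]
src_list = ["_" + k + "." for k in repl_list]                    # e.g. "_linear1."
tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]  # e.g. ".linear1.", ".modulation.lin."

def convert_key(k):
    # Kohya-style "lora_unet_*" prefixes -> dotted diffusion_model paths
    k = k.replace("lora_unet_blocks_", "diffusion_model.blocks.")
    k = k.replace("lora_unet__blocks_", "diffusion_model.blocks.")
    k = k.replace("lora_unet_single_blocks_", "diffusion_model.single_blocks.")
    # underscore-joined module names -> dotted module names
    for s, t in zip(src_list, tgt_list):
        k = k.replace(s, t)
    # Kohya up/down naming -> Diffusers/PEFT A/B naming
    k = k.replace("lora_up", "lora_B")
    k = k.replace("lora_down", "lora_A")
    return k

print(convert_key("lora_unet_single_blocks_7_linear1.lora_down.weight"))
# -> diffusion_model.single_blocks.7.linear1.lora_A.weight

Under these assumptions, the net effect is that Kohya "lora_unet_*" names are rewritten into the dotted "diffusion_model.*" module paths, with lora_down/lora_up mapped to the Diffusers-style lora_A/lora_B convention.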