# Mirrored from https://github.com/Wan-Video/Wan2.1.git
import torch
from torch import nn
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every nn.Linear in `module` with a LinearLora layer.

    The pretrained weight and bias are reused, so the base projection is
    unchanged; only the low-rank adapter added by LinearLora is new. The
    module is modified in place and nothing is returned.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )

            # Reuse the original parameters so the base mapping stays identical.
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None

            setattr(module, name, new_lora)
        else:
            # Recurse into containers and custom blocks.
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )


class LinearLora(nn.Linear):
    """nn.Linear augmented with an additive low-rank (LoRA) adapter.

    The output is the base linear projection plus
    scale * lora_B(lora_A(x)), where the adapter factors have rank `rank`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # The adapter rank can never exceed min(out_features, in_features).
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        # Low-rank factors: lora_A projects down to `rank` dimensions,
        # lora_B projects back up to out_features.
        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
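        # Note: nn.Linear's default init leaves lora_B non-zero, so a freshly
        # inserted adapter perturbs the output until it is trained (or until
        # scale is set to 0.0); the common LoRA convention instead
        # zero-initializes the up-projection so the adapter starts as a no-op.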
    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
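    # Forward: the base nn.Linear projection plus the scaled low-rank update,
    # i.e. y = x @ W.T + b + scale * lora_B(lora_A(x)).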
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)

        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
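

# A minimal usage sketch: the toy model, rank, and parameter-freezing policy
# below are arbitrary illustrative choices, not something prescribed by this
# file. It swaps the Linear layers for LinearLora, keeps only the adapter
# parameters trainable, and runs a forward pass.
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Linear(64, 128),
        nn.GELU(),
        nn.Linear(128, 64),
    )

    replace_linear_with_lora(model, max_rank=8, scale=1.0)

    # Train only the LoRA factors; keep the original weights frozen.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_A" in name or "lora_B" in name

    x = torch.randn(2, 64)
    y = model(x)
    print(y.shape)  # torch.Size([2, 64])

    # The adapter's contribution can be rescaled (or disabled) at any time.
    for m in model.modules():
        if isinstance(m, LinearLora):
            m.set_scale(0.0)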