# Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig


def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
    """Build a peft ``LoraConfig`` targeting the transformer's block Linear layers.

    Selects every ``torch.nn.Linear`` submodule whose qualified name contains
    ``"blocks"`` while excluding any name containing ``"face"`` or
    ``"modulation"``, and wires those names into the LoRA config.

    Args:
        transformer: module whose ``named_modules()`` will be scanned.
        rank: LoRA rank ``r``.
        alpha: LoRA scaling factor ``lora_alpha``.
        init_lora_weights: weight-initialization scheme passed through to peft.

    Returns:
        A ``peft.LoraConfig`` covering the selected Linear modules.
    """
    target_modules = [
        name
        for name, module in transformer.named_modules()
        if isinstance(module, torch.nn.Linear)
        and "blocks" in name
        and "face" not in name
        and "modulation" not in name
    ]
    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        init_lora_weights=init_lora_weights,
        target_modules=target_modules,
    )


class TensorList(object):
    """A list of tensors that mimics a subset of the ``torch.Tensor`` API.

    Wraps tensors that may have *different shapes* but must share ndim, dtype
    and device, and exposes elementwise arithmetic, dtype/device queries and a
    few tensor ops on the whole collection. The list axis behaves like an
    extra leading dimension: ``ndim`` is ``1 + member_ndim`` and ``size(0)``
    is the number of member tensors.
    """

    def __init__(self, tensors):
        """
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
        """
        assert isinstance(tensors, (list, tuple))
        assert all(isinstance(u, torch.Tensor) for u in tensors)
        # members may differ in shape, but must agree on rank, dtype, device
        assert len(set([u.ndim for u in tensors])) == 1
        assert len(set([u.dtype for u in tensors])) == 1
        assert len(set([u.device for u in tensors])) == 1
        self.tensors = tensors

    def to(self, *args, **kwargs):
        # same signature as torch.Tensor.to; returns a new TensorList
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])

    def size(self, dim):
        # only the leading (list) axis has a well-defined size; member
        # shapes may differ on every other axis
        assert dim == 0, 'only support get the 0th size'
        return len(self.tensors)

    def pow(self, *args, **kwargs):
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])

    def squeeze(self, dim):
        # dim 0 is the virtual list axis and cannot be squeezed. Positive
        # dims are shifted down by one to index into the member tensors;
        # negative dims already address the same trailing axis in both views.
        assert dim != 0
        if dim > 0:
            dim -= 1
        return TensorList([u.squeeze(dim) for u in self.tensors])

    def type(self, *args, **kwargs):
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])

    def type_as(self, other):
        """Cast members to the dtype of *other* (a Tensor or TensorList)."""
        assert isinstance(other, (torch.Tensor, TensorList))
        if isinstance(other, torch.Tensor):
            return TensorList([u.type_as(other) for u in self.tensors])
        else:
            return TensorList([u.type(other.dtype) for u in self.tensors])

    @property
    def dtype(self):
        # members are asserted dtype-uniform in __init__, so the first is
        # representative
        return self.tensors[0].dtype

    @property
    def device(self):
        return self.tensors[0].device

    @property
    def ndim(self):
        # +1 accounts for the virtual leading list axis
        return 1 + self.tensors[0].ndim

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def __add__(self, other):
        return self._apply(other, lambda u, v: u + v)

    def __radd__(self, other):
        return self._apply(other, lambda u, v: v + u)

    def __sub__(self, other):
        return self._apply(other, lambda u, v: u - v)

    def __rsub__(self, other):
        return self._apply(other, lambda u, v: v - u)

    def __mul__(self, other):
        return self._apply(other, lambda u, v: u * v)

    def __rmul__(self, other):
        return self._apply(other, lambda u, v: v * u)

    def __floordiv__(self, other):
        return self._apply(other, lambda u, v: u // v)

    def __truediv__(self, other):
        return self._apply(other, lambda u, v: u / v)

    def __rfloordiv__(self, other):
        return self._apply(other, lambda u, v: v // u)

    def __rtruediv__(self, other):
        return self._apply(other, lambda u, v: v / u)

    def __pow__(self, other):
        return self._apply(other, lambda u, v: u ** v)

    def __rpow__(self, other):
        return self._apply(other, lambda u, v: v ** u)

    def __neg__(self):
        return TensorList([-u for u in self.tensors])

    def __iter__(self):
        for tensor in self.tensors:
            yield tensor

    def __repr__(self):
        return 'TensorList: \n' + repr(self.tensors)

    def _apply(self, other, op):
        """Apply binary *op* elementwise against a sequence, tensor or scalar.

        Sequences (list/tuple/TensorList) and multi-element tensors are zipped
        member-by-member (lengths must match); scalars and one-element tensors
        are broadcast to every member.

        Raises:
            TypeError: if *other* is none of the supported operand kinds.
        """
        if isinstance(other, (list, tuple, TensorList)) or (
            isinstance(other, torch.Tensor) and (
                other.numel() > 1 or other.ndim > 1
            )
        ):
            assert len(other) == len(self.tensors)
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
        elif isinstance(other, numbers.Number) or (
            isinstance(other, torch.Tensor) and (
                other.numel() == 1 and other.ndim <= 1
            )
        ):
            return TensorList([op(u, other) for u in self.tensors])
        else:
            # BUGFIX: the message used to hardcode "*" even though _apply
            # backs every arithmetic operator (+, -, *, /, //, ** and the
            # reflected forms); keep it operator-neutral.
            raise TypeError(
                f'unsupported operand type(s) for TensorList: "{type(other)}"'
            )