Compare commits

...

3 Commits

Author          SHA1            Message                                                                            Date
shirubei        835d80e6ed      Merge 5f7e7ed289 into e5a741309d                                                   2025-05-17 20:06:08 +03:00
Shiwei Zhang    e5a741309d      Update README.md (#406)                                                            2025-05-17 10:57:06 +08:00
shirubei        5f7e7ed289      Update attention.py (adding support for cards that aren't Ampere architecture)    2025-02-28 23:29:31 +09:00
2 changed files with 92 additions and 44 deletions

README.md

@@ -643,7 +643,7 @@ If you find our work helpful, please cite us.
```
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}

attention.py

@@ -79,6 +79,7 @@ def flash_attention(
    k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
    v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
    try:
        q = q.to(v.dtype)
        k = k.to(v.dtype)
@@ -126,6 +127,53 @@ def flash_attention(
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))
    except RuntimeError as e:
if "FlashAttention only supports Ampere GPUs or newer" in str(e):
#for cards like 2080ti that aren't Ampere structure
from torch import nn
import torch.nn.functional as F
q = q.to(half(k).dtype)
# 转置维度,保证形状为 [B, N, L, C]
q = q.view(b, lq, q.size(1), q.size(2)).transpose(1, 2)
k = k.view(b, lk, k.size(1), k.size(2)).transpose(1, 2)
v = v.view(b, lk, v.size(1), v.size(2)).transpose(1, 2)
# 计算注意力
# 注意:确保 Q、K、V 的形状为 [B, N, L, C]
# 设置默认缩放因子
if softmax_scale is None:
softmax_scale = 1.0 / q.size(-1) ** 0.5
# 如果 q_scale 存在,则应用缩放
if q_scale is not None:
q = q * q_scale
# 创建掩码
if causal:
attn_mask = torch.triu(torch.full((q.size(2), k.size(2)), -torch.inf), diagonal=1).to(q.device)
else:
attn_mask = None
# 计算注意力
# 使用 scaled_dot_product_attention
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=causal,
)
# 转换回原形状 [B, L, N, C]
x = x.transpose(1, 2).contiguous()
# 对输出应用 Dropout
dropout = nn.Dropout(dropout_p)
x = dropout(x)
else:
raise
# output
return x.type(out_dtype)
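
For context on the fallback's mask handling: PyTorch's `F.scaled_dot_product_attention` does not accept an explicit `attn_mask` together with `is_causal=True`, so the fallback builds an upper-triangular additive mask itself and leaves `is_causal` off. The snippet below is a minimal, self-contained sketch (illustrative shapes, CPU float32 tensors, not code from this repository) showing that the explicit mask reproduces the built-in causal behavior.

```
# Illustrative check: an upper-triangular -inf mask with is_causal=False
# matches is_causal=True in scaled_dot_product_attention.
import torch
import torch.nn.functional as F

b, n, l, c = 1, 2, 8, 16                      # batch, heads, length, head dim
q = torch.randn(b, n, l, c)
k = torch.randn(b, n, l, c)
v = torch.randn(b, n, l, c)

# -inf strictly above the diagonal blocks attention to future positions.
mask = torch.triu(torch.full((l, l), -torch.inf), diagonal=1)

x_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
x_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(x_mask, x_causal, atol=1e-6))  # expected: True
```

The same reasoning is why the except branch above casts its mask to `q.dtype` and `q.device` before the call, so the equivalence also holds for half-precision inputs on the GPU.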