mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-17 12:57:40 +00:00)

Compare commits 680dcc74be...835d80e6ed

3 Commits:

- 835d80e6ed
- e5a741309d
- 5f7e7ed289
````diff
@@ -643,7 +643,7 @@ If you find our work helpful, please cite us.
 ```
 @article{wan2025,
     title={Wan: Open and Advanced Large-Scale Video Generative Models},
-    author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
+    author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
     journal = {arXiv preprint arXiv:2503.20314},
     year={2025}
 }
````
```diff
@@ -79,6 +79,7 @@ def flash_attention(
     k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
     v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
 
+    try:
     q = q.to(v.dtype)
     k = k.to(v.dtype)
 
```
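The fallback these commits add is triggered lazily, by catching the RuntimeError that flash_attn raises at call time on pre-Ampere cards (next hunk). An alternative is to probe the GPU's compute capability up front; Ampere and newer report major version 8 or higher. A minimal sketch, assuming a CUDA build of PyTorch; the helper name `supports_flash_attention` is illustrative and not part of this repo:

```python
import torch

def supports_flash_attention(device_index: int = 0) -> bool:
    """Illustrative check: FlashAttention needs Ampere (sm_80) or newer."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    return major >= 8

# A Turing card such as the 2080 Ti reports capability (7, 5) -> False;
# an A100 (Ampere) reports (8, 0) -> True.
```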
```diff
@@ -126,6 +127,53 @@ def flash_attention(
         window_size=window_size,
         deterministic=deterministic).unflatten(0, (b, lq))
 
+    except RuntimeError as e:
+        if "FlashAttention only supports Ampere GPUs or newer" in str(e):
+            # Fallback for cards like the 2080 Ti that are not Ampere architecture.
+            import torch.nn.functional as F
+
+            q = q.to(half(k).dtype)
+
+            # Transpose dimensions so q, k, v have shape [B, N, L, C].
+            q = q.view(b, lq, q.size(1), q.size(2)).transpose(1, 2)
+            k = k.view(b, lk, k.size(1), k.size(2)).transpose(1, 2)
+            v = v.view(b, lk, v.size(1), v.size(2)).transpose(1, 2)
+
+            # Set the default softmax scaling factor.
+            if softmax_scale is None:
+                softmax_scale = 1.0 / q.size(-1) ** 0.5
+
+            # Apply the extra query scale if one was given.
+            if q_scale is not None:
+                q = q * q_scale
+
+            # Build an explicit causal mask: upper triangle (future
+            # positions) set to -inf. The mask must match q's dtype.
+            if causal:
+                attn_mask = torch.triu(
+                    torch.full((q.size(2), k.size(2)), -torch.inf,
+                               dtype=q.dtype, device=q.device),
+                    diagonal=1)
+            else:
+                attn_mask = None
+
+            # Compute attention with scaled_dot_product_attention.
+            # is_causal must stay False here: SDPA rejects an explicit
+            # attn_mask combined with is_causal=True, and the mask above
+            # already encodes causality. dropout_p is applied inside
+            # SDPA, so no extra Dropout layer is needed afterwards.
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=False,
+                scale=softmax_scale,
+            )
+
+            # Transpose back to the original shape [B, L, N, C].
+            x = x.transpose(1, 2).contiguous()
+        else:
+            raise
 
     # output
     return x.type(out_dtype)
```
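For reference, here is a self-contained sketch of what this fallback path computes, outside the repo's flash_attention wrapper. Tensor sizes are made up for illustration, and nothing from this codebase is assumed; it requires PyTorch >= 2.1 for the `scale` argument of `F.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

b, heads, lq, lk, dim = 2, 8, 16, 16, 64  # illustrative sizes

q = torch.randn(b, lq, heads, dim)  # [B, L, N, C], as in the wrapper
k = torch.randn(b, lk, heads, dim)
v = torch.randn(b, lk, heads, dim)

# [B, L, N, C] -> [B, N, L, C], the layout SDPA expects.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

# Explicit causal mask: future positions get -inf before the softmax.
attn_mask = torch.triu(
    torch.full((lq, lk), -torch.inf, dtype=q.dtype), diagonal=1)

x = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,  # explicit mask, so is_causal stays False
    dropout_p=0.0,
    scale=dim ** -0.5,    # SDPA's default scale, written out
)

# Back to [B, L, N, C].
x = x.transpose(1, 2).contiguous()
print(x.shape)  # torch.Size([2, 16, 8, 64])
```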