mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 06:29:14 +00:00
Fixed Flash attention
This commit is contained in:
parent
18f3a31daf
commit
28f19586a5
@ -1161,7 +1161,7 @@ def create_demo():
|
|||||||
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
|
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
|
||||||
if args.multiple_images:
|
if args.multiple_images:
|
||||||
image_to_continue = gr.Gallery(
|
image_to_continue = gr.Gallery(
|
||||||
label="Images as a starting point for new videos", type ="pil", #file_types= "image",
|
label="Images as a starting point for new videos", type ="numpy", #file_types= "image",
|
||||||
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
|
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
|
||||||
else:
|
else:
|
||||||
image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
|
image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
|
||||||
|
|||||||
@ -201,7 +201,7 @@ def pay_attention(
|
|||||||
qkv_list = [q, k, v]
|
qkv_list = [q, k, v]
|
||||||
del q, k , v
|
del q, k , v
|
||||||
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
|
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
|
||||||
elif attn=="flash" and (version is None or version == 3):
|
elif attn=="flash" and version == 3:
|
||||||
# Note: dropout_p, window_size are not supported in FA3 now.
|
# Note: dropout_p, window_size are not supported in FA3 now.
|
||||||
x = flash_attn_interface.flash_attn_varlen_func(
|
x = flash_attn_interface.flash_attn_varlen_func(
|
||||||
q=q,
|
q=q,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user