From 0be33acb57ac525a824a7318598e74b6b04c9e06 Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:22:36 +0200 Subject: [PATCH] Fixed multitalk crash --- wan/multitalk/multitalk_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py index 4722dae..8462390 100644 --- a/wan/multitalk/multitalk_utils.py +++ b/wan/multitalk/multitalk_utils.py @@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli _, seq_lens, heads, _ = visual_q.shape class_num, _ = ref_target_masks.shape - x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype) + x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) split_chunk = heads // split_num @@ -350,4 +350,4 @@ def adaptive_projected_guidance( diff = diff * scale_factor diff_parallel, diff_orthogonal = project(diff, pred_cond) normalized_update = diff_orthogonal + eta * diff_parallel - return normalized_update \ No newline at end of file + return normalized_update