diff --git a/wan/multitalk/multitalk_utils.py b/wan/multitalk/multitalk_utils.py index 4722dae..8462390 100644 --- a/wan/multitalk/multitalk_utils.py +++ b/wan/multitalk/multitalk_utils.py @@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli _, seq_lens, heads, _ = visual_q.shape class_num, _ = ref_target_masks.shape - x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype) + x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device) split_chunk = heads // split_num @@ -350,4 +350,4 @@ def adaptive_projected_guidance( diff = diff * scale_factor diff_parallel, diff_orthogonal = project(diff, pred_cond) normalized_update = diff_orthogonal + eta * diff_parallel - return normalized_update \ No newline at end of file + return normalized_update