Fixed multitalk crash

This commit is contained in:
deepbeepmeep 2025-07-17 13:22:36 +02:00 committed by GitHub
parent a356c6af4b
commit 0be33acb57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -128,7 +128,7 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.device, device=visual_q.dtype)
x_ref_attn_maps = torch.zeros(class_num, seq_lens, dtype=visual_q.dtype, device=visual_q.device)
split_chunk = heads // split_num
@ -350,4 +350,4 @@ def adaptive_projected_guidance(
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
return normalized_update
return normalized_update