mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
15 lines
510 B
Python
15 lines
510 B
Python
import torch
|
|
|
|
|
|
def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate loss values into ``num_bins`` equal-width buckets of ``t``.

    Each element of ``t`` is mapped to bucket ``int(t * num_bins)``, clamped
    to ``[0, num_bins - 1]``. Values of ``t`` in ``[0, 1]`` therefore spread
    evenly over the bins, and ``t == 1.0`` lands in the last bin instead of
    indexing past the end. Inputs are detached, so no gradients are tracked.

    Args:
        loss: Loss values; flattened before binning.
        t: Per-element values that choose the bucket for each loss element
            (presumably normalized timesteps in ``[0, 1]`` -- confirm against
            callers). Flattened before binning; must contain the same number
            of elements as ``loss``.
        num_bins: Number of buckets; must be at least 1.

    Returns:
        ``(hist, count)``: the per-bin sum of ``loss`` and the per-bin number
        of contributing elements. Both are 1-D tensors of length ``num_bins``
        on ``loss.device``. Their dtype is ``loss.dtype`` promoted to at least
        float32, so low-precision losses (e.g. bf16) are summed in float32.

    Raises:
        ValueError: If ``num_bins < 1`` or if ``loss`` and ``t`` contain a
            different number of elements.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    # scatter_add_ requires src and destination dtypes to match. Accumulating
    # many bf16/fp16 values in their own precision also loses accuracy, so
    # sum in at least float32 (float32 input is unaffected).
    acc_dtype = torch.promote_types(loss.dtype, torch.float32)
    loss = loss.detach().flatten().to(acc_dtype)
    t = t.detach().flatten().to(loss.device)

    if t.numel() != loss.numel():
        # Without this check scatter_add_ silently drops surplus loss values
        # when t is shorter than loss.
        raise ValueError(
            f"loss and t must have the same number of elements, "
            f"got {loss.numel()} and {t.numel()}")

    # Truncate to a bin index, then clamp so t == 1.0 (or slightly
    # out-of-range values) cannot index past either end of the histogram.
    idx = (t * num_bins).long().clamp_(0, num_bins - 1)

    hist = torch.zeros(num_bins, dtype=acc_dtype, device=loss.device)
    count = torch.zeros(num_bins, dtype=acc_dtype, device=loss.device)
    hist.scatter_add_(0, idx, loss)
    count.scatter_add_(0, idx, torch.ones_like(loss))
    return hist, count
|