diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index cb200f23d76..92db1899851 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -88,7 +88,7 @@ def compute(self): return acc_global, acc, iu def reduce_from_all_processes(self): - reduce_across_processes(self.mat) + self.mat = reduce_across_processes(self.mat).to(torch.int64) def __str__(self): acc_global, acc, iu = self.compute()