From 7c38cb8a23173da667cf3b108095b5fa597312a7 Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Fri, 11 Oct 2024 09:31:17 +0900 Subject: [PATCH] Fix balanced sampler --- src/otx/algo/samplers/balanced_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otx/algo/samplers/balanced_sampler.py b/src/otx/algo/samplers/balanced_sampler.py index 287bbf1dcf4..4b6cfb56caa 100644 --- a/src/otx/algo/samplers/balanced_sampler.py +++ b/src/otx/algo/samplers/balanced_sampler.py @@ -65,7 +65,7 @@ def __init__( self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0} self.num_cls = len(self.img_indices.keys()) self.data_length = len(self.dataset) - self.num_trials = int(self.data_length / self.num_cls) + self.num_trials = max(int(self.data_length / self.num_cls), 1) if efficient_mode: # Reduce the # of sampling (sampling data for a single epoch)