diff --git a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py index 7ca347fe6c6..f59225e3b8a 100644 --- a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py +++ b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py @@ -143,7 +143,7 @@ def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self): """Batch size 2 doesn't make oom but use most of memory.""" mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1) - bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 2, 1000) assert bs_search_algo.find_big_enough_batch_size() == 2 def test_find_big_enough_batch_size_gradient_zero(self):