diff --git a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py index 8e1078029cf..7ca347fe6c6 100644 --- a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py +++ b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py @@ -107,10 +107,10 @@ def test_find_max_usable_bs_gpu_memory_too_small(self): def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self): """Batch size 2 doesn't make oom but use most of memory.""" - mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1) bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) - bs_search_algo.auto_decrease_batch_size() + assert bs_search_algo.auto_decrease_batch_size() == 2 @pytest.mark.parametrize( ("max_runnable_bs", "max_bs", "expected_bs"), @@ -141,11 +141,10 @@ def test_find_big_enough_batch_size_gpu_memory_too_small(self): def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self): """Batch size 2 doesn't make oom but use most of memory.""" - mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1) bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) - with pytest.raises(RuntimeError): - bs_search_algo.find_big_enough_batch_size() + assert bs_search_algo.find_big_enough_batch_size() == 2 def test_find_big_enough_batch_size_gradient_zero(self): def mock_train_func(batch_size) -> int: