From 6b9153a4e6f17cd3db1d8b8bbe1a65aa017d0f7f Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Thu, 10 Oct 2024 15:13:37 +0900 Subject: [PATCH] update unit test --- tests/unit/engine/adaptive_bs/test_bs_search_algo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py index 8e1078029cf..7ca347fe6c6 100644 --- a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py +++ b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py @@ -107,10 +107,10 @@ def test_find_max_usable_bs_gpu_memory_too_small(self): def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self): """Batch size 2 doesn't make oom but use most of memory.""" - mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1) bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) - bs_search_algo.auto_decrease_batch_size() + assert bs_search_algo.auto_decrease_batch_size() == 2 @pytest.mark.parametrize( ("max_runnable_bs", "max_bs", "expected_bs"), @@ -141,11 +141,10 @@ def test_find_big_enough_batch_size_gpu_memory_too_small(self): def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self): """Batch size 2 doesn't make oom but use most of memory.""" - mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1) bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) - with pytest.raises(RuntimeError): - bs_search_algo.find_big_enough_batch_size() + assert bs_search_algo.find_big_enough_batch_size() == 2 def test_find_big_enough_batch_size_gradient_zero(self): def mock_train_func(batch_size) -> int: