Skip to content

Commit

Permalink
update unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Oct 10, 2024
1 parent 360bd23 commit 6b9153a
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/unit/engine/adaptive_bs/test_bs_search_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def test_find_max_usable_bs_gpu_memory_too_small(self):

def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self):
"""Batch size 2 doesn't make oom but use most of memory."""
mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1)
mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)

bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
bs_search_algo.auto_decrease_batch_size()
assert bs_search_algo.auto_decrease_batch_size() == 2

@pytest.mark.parametrize(
("max_runnable_bs", "max_bs", "expected_bs"),
Expand Down Expand Up @@ -141,11 +141,10 @@ def test_find_big_enough_batch_size_gpu_memory_too_small(self):

def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self):
"""Batch size 2 doesn't make oom but use most of memory."""
mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)
mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)

bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
with pytest.raises(RuntimeError):
bs_search_algo.find_big_enough_batch_size()
assert bs_search_algo.find_big_enough_batch_size() == 2

def test_find_big_enough_batch_size_gradient_zero(self):
def mock_train_func(batch_size) -> int:
Expand Down

0 comments on commit 6b9153a

Please # to comment.