Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Use fixtures #65

Merged
merged 23 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cascade/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def __getitem__(self, item):
def __iter__(self):
for item in self._data:
yield item

def get_meta(self):
meta = super().get_meta()
meta[0]['obj_type'] = str(type(self._data))
return meta


class Wrapper(Dataset):
Expand All @@ -95,7 +100,7 @@ def __len__(self) -> int:
def get_meta(self):
meta = super().get_meta()
meta[0]['len'] = len(self)
meta[0]['obj_type'] = type(self._data)
meta[0]['obj_type'] = str(type(self._data))
return meta


Expand Down
6 changes: 6 additions & 0 deletions cascade/models/model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,9 @@ def get_meta(self) -> List[Dict]:
'type': 'repo'
})
return meta

def __del__(self):
# Release all files on desctruction
for handler in self.logger.handlers:
handler.close()
self.logger.removeHandler(handler)
60 changes: 58 additions & 2 deletions cascade/tests/dummy_model.py → cascade/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@
limitations under the License.
"""


import os
import sys
import random
import shutil
import numpy as np
import pytest

MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(MODULE_PATH))

sys.path.append(os.path.abspath('../..'))
from cascade.models import Model
from cascade.data import Wrapper, Iterator
from cascade.models import Model, ModelRepo


class DummyModel(Model):
Expand Down Expand Up @@ -47,3 +54,52 @@ def save(self, path):
path += '.bin'
with open(path, 'wb') as f:
f.write(b'model')


class EmptyModel(DummyModel):
def __init__(self):
pass


@pytest.fixture(params=[
[1, 2, 3, 4, 5],
[0],
[0, 0, 0, 0],
[-i for i in range(0, 100)]
])
def number_dataset(request):
return Wrapper(request.param)


@pytest.fixture(params=[
[1, 2, 3, 4, 5],
[0],
[0, 0, 0, 0],
[-i for i in range(100, 0)]
])
def number_iterator(request):
return Iterator(request.param)


@pytest.fixture(params=[
{'a': 0},
{'b': 1},
{'a': 0, 'b': 'alala'},
{'c': np.array([1, 2]), 'd': {'a': 0}}])
def dummy_model(request):
return DummyModel(**request.param)


@pytest.fixture
def empty_model():
return EmptyModel()


@pytest.fixture
def model_repo(tmp_path):
repo = ModelRepo(str(tmp_path), lines=[
dict(
name=str(num),
cls=DummyModel) for num in range(10)
])
yield repo
38 changes: 0 additions & 38 deletions cascade/tests/number_dataset.py

This file was deleted.

30 changes: 0 additions & 30 deletions cascade/tests/number_iterator.py

This file was deleted.

8 changes: 4 additions & 4 deletions cascade/tests/test_apply_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@
SCRIPT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from cascade.tests.number_dataset import NumberDataset
from cascade.data import Wrapper


@pytest.mark.parametrize(
'arr, func', [
([1, 2, 3, 4, 5], lambda x: x * 2),
([1, 2, 3, 4, 5], lambda x: x ** 2),
([1, 2, 3, 4, 5], lambda x: x)
([1], lambda x: x ** 2),
([1, 2, -3], lambda x: x)
]
)
def test_apply_modifier(arr, func):
ds = NumberDataset(arr)
ds = Wrapper(arr)
ds = ApplyModifier(ds, func)
assert(list(map(func, arr)) == [item for item in ds])
22 changes: 10 additions & 12 deletions cascade/tests/test_bruteforce_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,23 @@
MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.tests.number_dataset import NumberDataset
from cascade.tests.number_iterator import NumberIterator
from cascade.data import BruteforceCacher
from cascade.data import BruteforceCacher, Wrapper


def test_ds():
ds = NumberDataset([1, 2, 3, 4, 5])
ds = BruteforceCacher(ds)
assert([1, 2, 3, 4, 5] == [item for item in ds])
def test_ds(number_dataset):
ds = BruteforceCacher(number_dataset)
assert([number_dataset[i] for i in range(len(number_dataset))] \
== [item for item in ds])


def test_it():
ds = NumberIterator(6)
ds = BruteforceCacher(ds)
assert([0, 1, 2, 3, 4, 5] == [item for item in ds])
def test_it(number_iterator):
ds = BruteforceCacher(number_iterator)
assert([item for item in number_iterator] \
== [item for item in ds])


def test_meta():
ds = NumberDataset([1, 2, 3, 4, 5])
ds = Wrapper([1, 2, 3, 4, 5])
ds = BruteforceCacher(ds)
meta = ds.get_meta()

Expand Down
30 changes: 20 additions & 10 deletions cascade/tests/test_concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,36 @@

import os
import sys
import pytest

MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.tests.number_dataset import NumberDataset
from cascade.data import Wrapper
from cascade.data import Concatenator


def test_meta():
n1 = NumberDataset([0, 1])
n2 = NumberDataset([2, 3, 4, 5])
n1 = Wrapper([0, 1])
n2 = Wrapper([2, 3, 4, 5])

c = Concatenator([n1, n2], meta_prefix={'num': 1})
assert(c.get_meta()[0]['num'] == 1)

def test_concatenation():
n1 = NumberDataset([0, 1])
n2 = NumberDataset([2, 3, 4, 5])
n3 = NumberDataset([6, 7, 8])
n4 = NumberDataset([1])

c = Concatenator([n1, n2, n4, n3, n4])
assert([c[i] for i in range(len(c))] == [0, 1, 2, 3, 4, 5, 1, 6, 7, 8, 1])
@pytest.mark.parametrize(
'arrs', [
([0],[0],[0]),
([1,2,3,4], [11]),
([1],),
([1,2,3,4], [])
]
)
def test_concatenation(arrs):
c = Concatenator([*arrs])

res = []
for arr in arrs:
res += arr

assert([c[i] for i in range(len(c))] == res)
4 changes: 2 additions & 2 deletions cascade/tests/test_cyclic_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.data import CyclicSampler
from cascade.tests.number_dataset import NumberDataset
from cascade.data import Wrapper


def test_cycle():
ds = NumberDataset([0, 1, 2, 3, 4])
ds = Wrapper([0, 1, 2, 3, 4])
ds = CyclicSampler(ds, 16)

assert([ds[i] for i in range(len(ds))] ==
Expand Down
13 changes: 6 additions & 7 deletions cascade/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.data import Dataset, Iterator, Wrapper, Modifier, Sampler
from cascade.tests.number_dataset import NumberDataset


def test_getitem():
Expand Down Expand Up @@ -53,11 +52,11 @@ def test_update_meta():
assert(meta[0]['b'] == 3)


def test_meta_from_file():
with open('test_meta_from_file.json', 'w') as f:
def test_meta_from_file(tmp_path):
with open(os.path.join(tmp_path, 'test_meta_from_file.json'), 'w') as f:
json.dump({'a': 1}, f)

ds = Dataset(meta_prefix='test_meta_from_file.json')
ds = Dataset(meta_prefix=os.path.join(tmp_path, 'test_meta_from_file.json'))
meta = ds.get_meta()

assert('a' in meta[0])
Expand All @@ -81,7 +80,7 @@ def test_wrapper():


def test_modifier():
ds = NumberDataset([1, 2, 3, 4])
ds = Wrapper([1, 2, 3, 4])
ds = Modifier(ds)

res = []
Expand All @@ -96,7 +95,7 @@ def test_modifier():


def test_modifier_meta():
ds = NumberDataset([1, 2, 3, 4])
ds = Wrapper([1, 2, 3, 4])
ds = Modifier(ds)

meta = ds.get_meta()
Expand All @@ -105,5 +104,5 @@ def test_modifier_meta():


def test_sampler():
ds = NumberDataset([1, 2, 3, 4])
ds = Wrapper([1, 2, 3, 4])
ds = Sampler(ds, 10)
9 changes: 4 additions & 5 deletions cascade/tests/test_folder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from cascade.data import FolderDataset


def test():
folder = './folder_dataset'
os.mkdir(folder)
with open('./folder_dataset/0.txt', 'w') as w:
def test(tmp_path):
tmp_path = str(tmp_path)
with open(os.path.join(tmp_path, '0.txt'), 'w') as w:
w.write('hello')

ds = FolderDataset(folder)
ds = FolderDataset(tmp_path)
meta = ds.get_meta()[0]

assert(len(ds) == 1)
Expand Down
Loading