forked from libffcv/ffcv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_partial_batches.py
80 lines (63 loc) · 2.29 KB
/
test_partial_batches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from dataclasses import replace
import torch as ch
from ffcv.pipeline.allocation_query import AllocationQuery
from ffcv.pipeline.compiler import Compiler
import numpy as np
from typing import Callable
from assertpy import assert_that
from torch.utils.data import Dataset
import logging
import os
from assertpy import assert_that
from tempfile import NamedTemporaryFile
from ffcv.pipeline.operation import Operation
from ffcv.transforms.ops import ToTensor
from multiprocessing import cpu_count
from ffcv.writer import DatasetWriter
from ffcv.reader import Reader
from ffcv.loader import Loader
from ffcv.fields import IntField, FloatField, BytesField
from ffcv.fields.basics import FloatDecoder
from ffcv.pipeline.state import State
from test_writer import DummyDataset
# Silence numba's verbose DEBUG/INFO output during test runs.
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)
class Doubler(Operation):
    """Test pipeline stage that multiplies every sample by two."""

    def generate_code(self) -> Callable:
        # Slicing by the incoming batch's leading dimension keeps the
        # write in-bounds even when the final batch is partial (smaller
        # than the allocated destination buffer).
        def double_into(batch, destination):
            count = batch.shape[0]
            destination[:count] = batch * 2
            return destination

        return double_into

    def declare_state_and_memory(self, previous_state: State):
        # Output has the same shape/dtype/device as the input, so the
        # state passes through unchanged and we request a matching buffer.
        query = AllocationQuery(previous_state.shape,
                                previous_state.dtype,
                                previous_state.device)
        return (previous_state, query)
def run_test(bs, exp_length, drop_last=True):
    """Write a 600-sample dummy dataset, load it back with the given
    batch size and ``drop_last`` setting, and check both the reported
    loader length and the partial-batch behavior.

    bs: requested batch size.
    exp_length: expected number of batches (len(loader)).
    drop_last: forwarded to the Loader; when True, no partial batch
        may be yielded at all, otherwise at most one (the last).
    """
    dataset_size = 600
    with NamedTemporaryFile() as handle:
        file_name = handle.name

        writer = DatasetWriter(file_name, {
            'index': IntField(),
            'value': FloatField()
        })
        writer.from_indexed_dataset(DummyDataset(dataset_size))

        Compiler.set_enabled(True)

        loader = Loader(file_name, bs,
                        num_workers=min(5, cpu_count()),
                        seed=17,
                        drop_last=drop_last,
                        pipelines={
                            'value': [FloatDecoder(), Doubler(), ToTensor()]
                        })

        assert_that(loader).is_length(exp_length)

        # Sentinel: starts True when drop_last (any partial batch is an
        # error), False otherwise (exactly one trailing partial allowed —
        # a second partial batch trips the assertion).
        partial_forbidden = drop_last
        for batch, _ in loader:
            if batch.shape[0] != bs:
                assert_that(partial_forbidden).is_false()
                partial_forbidden = True
def test_partial():
    """600 samples / bs 7 with drop_last: 85 full batches, none partial."""
    run_test(bs=7, exp_length=85, drop_last=True)
def test_not_partial():
    """600 samples / bs 7 without drop_last: 86 batches, last one partial."""
    run_test(bs=7, exp_length=86, drop_last=False)
def test_not_partial_multiple():
    """600 samples / bs 60 (divides evenly) without drop_last: 10 batches."""
    run_test(bs=60, exp_length=10, drop_last=False)
def test_partial_multiple():
    """600 samples / bs 60 (divides evenly) with drop_last: still 10 batches."""
    run_test(bs=60, exp_length=10, drop_last=True)