Skip to content

Commit

Permalink
Add tests with pandas and numpy and comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
andLaing committed Nov 26, 2020
1 parent 81cf00b commit 35a1c31
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions invisible_cities/detsim/buffer_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ def test_find_signal_start(binned_waveforms, signal_thresh):
assert np.all(pmt_sum[pulses] > signal_thresh)


@mark.parametrize("signal_thresh", (2, 10))
def test_find_signal_start_numpy(binned_waveforms, signal_thresh):

pmt_bins, pmt_wfs, *_ = binned_waveforms

buffer_length = 800 * units.mus
bin_width = np.diff(pmt_bins)[0]
stand_off = int(buffer_length / bin_width)

pmt_sum = pmt_wfs.sum()
pmt_wfs_np = np.asarray(pmt_wfs.tolist())
pulses = find_signal_start(pmt_wfs_np, signal_thresh, stand_off)

assert np.all(pmt_sum[pulses] > signal_thresh)


def test_find_signal_start_correct_index():

thresh = 5
Expand Down Expand Up @@ -109,3 +125,38 @@ def test_buffer_calculator(mc_waveforms, binned_waveforms,
assert evt_pmt .shape[1] == int(buffer_length / pmt_binwid)
assert evt_sipm.shape[1] == int(buffer_length / sipm_binwid)
assert np.sum(evt_pmt, axis=0)[pre_trg_samp] == pmt_sum[pulses[i]]


def test_buffer_calculator_pandas_numpy(mc_waveforms, binned_waveforms):

_, pmt_binwid, sipm_binwid, _ = mc_waveforms

pmt_bins, pmt_wfs, sipm_bins, sipm_wfs = binned_waveforms

buffer_length = 800 * units.mus
bin_width = np.diff(pmt_bins)[0]
stand_off = int(buffer_length / bin_width)
pre_trigger = 100 * units.mus
signal_thresh = 2

pulses_pd = find_signal_start(pmt_wfs, signal_thresh, stand_off)
pmt_nparr = np.asarray(pmt_wfs.tolist())
pulses_np = find_signal_start(pmt_nparr, signal_thresh, stand_off)

calculate_buffers = buffer_calculator(buffer_length,
pre_trigger ,
pmt_binwid ,
sipm_binwid )

buffers_pd = calculate_buffers(pulses_pd, *binned_waveforms)
buffers_np = calculate_buffers(pulses_np ,
pmt_bins ,
pmt_nparr ,
sipm_bins ,
np.asarray(sipm_wfs.tolist()))

assert len(buffers_pd) == len(buffers_np) == 1
evtpd_buffers = buffers_pd[0]
evtnp_buffers = buffers_np[0]
assert np.all([np.all(evtpd_buffers[0] == evtnp_buffers[0]),
np.all(evtpd_buffers[1] == evtnp_buffers[1])])

0 comments on commit 35a1c31

Please # to comment.