Skip to content

Commit

Permalink
Subset refactor part 1: new get_data API (#1984)
Browse files Browse the repository at this point in the history
* First draft of refactor planning

* Add new get_data method to ConfigHelper

* Cover case where subsets_to_apply is None

* Update comments and raise valueerror if valid data_label not provided

* Address review comments

* Remove unused imports

* Fix f-string issue

* Move if statement outside loop

* Move cls code

* Change cls to be None by default

* Rearrange cls and subset_to_apply blocks

* Fix docstrings

* Add tests in imviz and specviz2d

* Cover specviz2d specific line with test

* Add more tests

* Remove concept notebook

* Remove changes to notebooks

* Add check for statistic values

* Remove percentile from statistic options

* Remove percentile from docstring too

* Fix bug with imviz data and Subset 1 not returning an NDData object

* Fix code after last commit and add tests

* Return spectrum1d obj in specviz2d with data of shape 2

* Only allow statistic to be set in cubeviz

* Update changes file
  • Loading branch information
javerbukh authored Feb 14, 2023
1 parent 9dcca5f commit 2d767af
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Specviz2d
API Changes
-----------

- Add ``get_data()`` method to base helper class to centralize data retrieval. [#1984]

Cubeviz
^^^^^^^

Expand Down
30 changes: 30 additions & 0 deletions jdaviz/configs/cubeviz/plugins/tests/test_cubeviz_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import pytest

from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose
from specutils import Spectrum1D


def test_nested_helper(cubeviz_helper):
'''Ensures the Cubeviz helper is always returned, even after the Specviz helper is called'''
# Force Specviz helper to instantiate
Expand All @@ -15,3 +22,26 @@ def test_plugin_user_apis(cubeviz_helper):
plugin = plugin_api._obj
for attr in plugin_api._expose:
assert hasattr(plugin, attr)


def test_invalid_statistic(cubeviz_helper, spectrum1d_cube):
cubeviz_helper.load_data(spectrum1d_cube, "test")
cubeviz_helper._apply_interactive_region('bqplot:ellipse', (0, 0), (9, 8))

with pytest.raises(ValueError, match='statistic 42 not in list of valid '):
cubeviz_helper.get_data(data_label="test[FLUX]", subset_to_apply='Subset 1', statistic=42)


def test_valid_statistic(cubeviz_helper, spectrum1d_cube):
cubeviz_helper.load_data(spectrum1d_cube, "test")
cubeviz_helper._apply_interactive_region('bqplot:ellipse', (0, 0), (9, 8))

results_min = cubeviz_helper.get_data(data_label="test[FLUX]",
subset_to_apply='Subset 1', statistic="minimum")
results_max = cubeviz_helper.get_data(data_label="test[FLUX]",
subset_to_apply='Subset 1', statistic="maximum")
assert isinstance(results_min, Spectrum1D)
assert_quantity_allclose(results_min.flux,
[6., 14.] * u.Jy, atol=1e-5 * u.Jy)
assert_quantity_allclose(results_max.flux,
[7., 15.] * u.Jy, atol=1e-5 * u.Jy)
3 changes: 3 additions & 0 deletions jdaviz/configs/imviz/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def test_create_new_viewer(imviz_helper, image_2d_wcs):
imviz_helper.load_data(arr, data_label=data_label, show_in_viewer=False)
imviz_helper.create_image_viewer(viewer_name=viewer_name)

returned_data = imviz_helper.get_data(data_label)
assert len(returned_data.shape) == 2

# new image viewer created
assert len(imviz_helper.app.get_viewer_ids()) == 2

Expand Down
6 changes: 6 additions & 0 deletions jdaviz/configs/imviz/tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from astropy import units as u
from astropy.coordinates import SkyCoord, Angle
from astropy.utils.data import get_pkg_data_filename
from astropy.nddata import NDData
from photutils.aperture import CircularAperture, SkyCircularAperture
from regions import (PixCoord, CircleSkyRegion, RectanglePixelRegion, CirclePixelRegion,
EllipsePixelRegion, PointPixelRegion, PointSkyRegion, PolygonPixelRegion,
Expand Down Expand Up @@ -259,6 +260,11 @@ def test_photutils_sky_has_wcs(self):
self.verify_region_loaded('my_aper_sky_1')
assert self.imviz.get_interactive_regions() == {}

def test_get_data_with_region(self):
self.imviz._apply_interactive_region('bqplot:rectangle', (0, 0), (10, 10))
results = self.imviz.get_data('has_wcs[SCI,1]', subset_to_apply='Subset 1')
assert isinstance(results, NDData)


class TestLoadRegionsFromFile(BaseRegionHandler):

Expand Down
11 changes: 9 additions & 2 deletions jdaviz/configs/specviz2d/tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from specutils import Spectrum1D
from jdaviz import Specviz


def test_helper(specviz2d_helper, spectrum1d):
specviz2d_helper.load_data(spectrum_1d=spectrum1d)
def test_helper(specviz2d_helper, mos_spectrum2d):
specviz2d_helper.load_data(spectrum_2d=mos_spectrum2d)
assert isinstance(specviz2d_helper.specviz, Specviz)

specviz2d_helper.app.data_collection[0].meta['Trace'] = "Test"

returned_data = specviz2d_helper.get_data("Spectrum 2D")
assert len(returned_data.shape) == 1
assert isinstance(returned_data, Spectrum1D)


def test_plugin_user_apis(specviz2d_helper):
for plugin_name, plugin_api in specviz2d_helper.plugins.items():
Expand Down
97 changes: 97 additions & 0 deletions jdaviz/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
import re
import warnings
from contextlib import contextmanager
from inspect import isclass

import numpy as np
import astropy.units as u
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astropy.nddata import CCDData
from glue.core import HubListener
from glue.core.edit_subset_mode import NewMode
from glue.core.message import SubsetCreateMessage, SubsetDeleteMessage
from glue.core.subset import Subset, MaskSubsetState
from glue.config import data_translator
from ipywidgets.widgets import widget_serialization
from specutils import Spectrum1D


from jdaviz.app import Application
from jdaviz.core.events import SnackbarMessage, ExitBatchLoadMessage
Expand Down Expand Up @@ -402,6 +407,98 @@ def show_in_new_tab(self, title=None): # pragma: no cover
DeprecationWarning)
return self.show(loc="sidecar:tab-after", title=title)

def get_data(self, data_label=None, cls=None, subset_to_apply=None, statistic=None):
"""
Returns data with name equal to data_label of type cls with subsets applied from
subset_to_apply.
Parameters
----------
data_label : str, optional
Provide a label to retrieve a specific data set from data_collection.
cls : `~specutils.Spectrum1D`, `~astropy.nddata.CCDData`, optional
The type that data will be returned as.
subset_to_apply : str, optional
Subset that is to be applied to data before it is returned.
statistic : {'minimum', 'maximum', 'mean', 'median', 'sum'}, optional
The statistic to use to collapse the dataset.
Returns
-------
data : cls
Data is returned as type cls with subsets applied.
"""
if self.app.config != "cubeviz" and statistic:
raise AttributeError(f"{self.app.config} does not need the statistic parameter set.")

list_of_valid_statistic_values = ['minimum', 'maximum', 'mean',
'median', 'sum']
if statistic and statistic not in list_of_valid_statistic_values:
raise ValueError(f"statistic {statistic} not in list of valid"
f" statistic values {list_of_valid_statistic_values}")

list_of_valid_subset_names = [x.label for x in self.app.data_collection.subset_groups]
if subset_to_apply and subset_to_apply not in list_of_valid_subset_names:
raise ValueError(f"Subset {statistic} not in list of valid"
f" subset names {list_of_valid_subset_names}")

if data_label and data_label not in self.app.data_collection.labels:
raise ValueError(f'{data_label} not in {self.app.data_collection.labels}.')
elif not data_label and len(self.app.data_collection) > 1:
raise ValueError('data_label must be set if more than'
' one data exists in data_collection.')
elif not data_label and len(self.app.data_collection) == 1:
data_label = self.app.data_collection[0].label

if cls is not None and not isclass(cls):
raise TypeError(
"cls in get_data must be a class or None.")
data = self.app.data_collection[data_label]

if not cls:
if len(data.shape) == 2 and self.app.config == "specviz2d":
cls = Spectrum1D
elif len(data.shape) == 2:
cls = CCDData
elif len(data.shape) in [1, 3]:
cls = Spectrum1D
if not subset_to_apply:
if 'Trace' in data.meta:
data = data.get_object()
elif cls == Spectrum1D:
data = data.get_object(cls=cls, statistic=statistic)
else:
data = data.get_object(cls=cls)

return data

if not cls and subset_to_apply:
raise AttributeError(f"A valid cls must be provided to"
f" apply subset {subset_to_apply} to data. "
f"Instead, {cls} was given.")

# Loop through each subset
for subsets in self.app.data_collection.subset_groups:
# If name matches the name in subsets_to_apply, continue
if subsets.label.lower() == subset_to_apply.lower():
# Loop through each data a subset applies to
for subset in subsets.subsets:
# If the subset applies to data with the same name as data_label, continue
if subset.data.label == data_label:

handler, _ = data_translator.get_handler_for(cls)
try:
if cls == Spectrum1D:
data = handler.to_object(subset, statistic=statistic)
else:
data = handler.to_object(subset)
except Exception as e:
warnings.warn(f"Not able to get {data_label} returned with"
f" subset {subsets.label} applied of type {cls}."
f" Exception: {e}")
return data


class ImageConfigHelper(ConfigHelper):
"""`ConfigHelper` that uses an image viewer as its primary viewer.
Expand Down
73 changes: 73 additions & 0 deletions jdaviz/core/tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import pytest

from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose
from glue.core.edit_subset_mode import NewMode
from glue.core.roi import XRangeROI

from jdaviz.core.helpers import _next_subset_num


Expand All @@ -21,3 +26,71 @@ def __init__(self, label):
def test_next_subset_num(label, prefix, answer):
mocked_group = [MockGroupItem(label)]
assert _next_subset_num(prefix, mocked_group) == answer


class TestConfigHelper:
@pytest.fixture(autouse=True)
def setup_class(self, specviz_helper, spectrum1d, multi_order_spectrum_list):
self.spec_app = specviz_helper
self.spec = spectrum1d
self.label = "Test 1D Spectrum"

self.spec2 = spectrum1d._copy(spectral_axis=spectrum1d.spectral_axis+1000*u.AA)
self.label2 = "Test 1D Spectrum 2"
self.spec_app.load_spectrum(spectrum1d, data_label=self.label)
self.spec_app.load_spectrum(self.spec2, data_label=self.label2)

# Add 3 subsets to cover different parts of spec and spec2
self.spec_app.app.get_viewer("spectrum-viewer").apply_roi(XRangeROI(6000, 6500))
self.spec_app.app.session.edit_subset_mode.mode = NewMode

self.spec_app.app.get_viewer("spectrum-viewer").apply_roi(XRangeROI(6700, 7200))
self.spec_app.app.get_viewer("spectrum-viewer").apply_roi(XRangeROI(8200, 8800))

@pytest.mark.parametrize(
('label', 'subset_name', 'answer'),
[('Test 1D Spectrum', 'Subset 1',
[False, False, False, True, True, True, True, True, True, True]),
('Test 1D Spectrum', 'Subset 2',
[True, True, True, True, False, False, True, True, True, True]),
('Test 1D Spectrum', 'Subset 3',
[True, True, True, True, True, True, True, True, True, True]),
('Test 1D Spectrum 2', 'Subset 1',
[True, True, True, True, True, True, True, True, True, True]),
('Test 1D Spectrum 2', 'Subset 2',
[False, True, True, True, True, True, True, True, True, True]),
('Test 1D Spectrum 2', 'Subset 3',
[True, True, True, True, True, True, False, False, False, True])])
def test_get_data_with_one_subset_per_data(self, specviz_helper, label, subset_name, answer):

results = specviz_helper.get_data(data_label=label,
subset_to_apply=subset_name,
statistic=None)
assert list(results.mask) == answer

def test_get_data_no_label_multiple_in_dc(self, specviz_helper):
with pytest.raises(ValueError, match='data_label must be set if more'):
specviz_helper.get_data()

def test_get_data_label_not_in_dc(self, specviz_helper):
with pytest.raises(ValueError, match='Blah not in '):
specviz_helper.get_data(data_label="Blah")

def test_get_data_no_label_one_in_dc(self, specviz_helper):
specviz_helper.app.data_collection.remove(specviz_helper.app.data_collection[self.label2])
results = specviz_helper.get_data()
assert_quantity_allclose(results.flux,
self.spec.flux, atol=1e-5 * u.Unit(self.spec.flux.unit))

def test_get_data_invald_cls_class(self, specviz_helper):
specviz_helper.app.data_collection.remove(specviz_helper.app.data_collection[self.label2])
with pytest.raises(TypeError, match="cls in get_data must be a class or None."):
specviz_helper.get_data('Test 1D Spectrum', cls=42)

def test_get_data_invald_subset_name(self, specviz_helper):
with pytest.raises(ValueError, match="not in list of valid subset names"):
specviz_helper.get_data('Test 1D Spectrum', subset_to_apply="Fail")

def test_get_data_not_needed_statistic(self, specviz_helper):
with pytest.raises(AttributeError, match="does not need the statistic parameter set"):
specviz_helper.get_data('Test 1D Spectrum', statistic="mean")

0 comments on commit 2d767af

Please # to comment.