diff --git a/jdaviz/app.py b/jdaviz/app.py index 7cabcd1200..4e2b89fca0 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -2065,13 +2065,12 @@ def set_data_visibility(self, viewer_reference, data_label, visible=True, replac # if Data has children, update their visibilities to match Data: assoc_children = self._get_assoc_data_children(data_label) for layer in viewer.layers: - for data_label in assoc_children: - if layer.layer.data.label == data_label: - if visible and not layer.visible: - layer.visible = True - layer.update() - else: - layer.visible = visible + if layer.layer.data.label in assoc_children: + if visible and not layer.visible: + layer.visible = True + layer.update() + else: + layer.visible = visible # update data menu - selected_data_items should be READ ONLY, not modified by the user/UI selected_items = viewer_item['selected_data_items'] diff --git a/jdaviz/components/tooltip.vue b/jdaviz/components/tooltip.vue index 35baf3818e..00c5a73730 100644 --- a/jdaviz/components/tooltip.vue +++ b/jdaviz/components/tooltip.vue @@ -105,6 +105,8 @@ const tooltips = { 'plugin-collapse-save-fits': 'Save collapsed cube as FITS file', 'plugin-link-apply': 'Apply linking to data', 'plugin-footprints-color-picker': 'Change the color of the footprint overlay', + 'plugin-dq-show-all': 'Show all quality flags', + 'plugin-dq-hide-all': 'Hide all quality flags', } diff --git a/jdaviz/configs/default/default.yaml b/jdaviz/configs/default/default.yaml index 4dda8cbb79..ebf9e8449f 100644 --- a/jdaviz/configs/default/default.yaml +++ b/jdaviz/configs/default/default.yaml @@ -15,4 +15,4 @@ toolbar: tray: - g-subset-plugin - g-gaussian-smooth - - export + - export \ No newline at end of file diff --git a/jdaviz/configs/default/plugins/__init__.py b/jdaviz/configs/default/plugins/__init__.py index e516c2dd82..a79315363a 100644 --- a/jdaviz/configs/default/plugins/__init__.py +++ b/jdaviz/configs/default/plugins/__init__.py @@ -11,3 +11,4 @@ from .export.export import * # noqa from .plot_options.plot_options import * # noqa from .markers.markers import * # noqa +from .data_quality.data_quality import * # noqa diff --git a/jdaviz/configs/default/plugins/data_quality/__init__.py b/jdaviz/configs/default/plugins/data_quality/__init__.py index 8e834ef147..9fd5bcb8ff 100644 --- a/jdaviz/configs/default/plugins/data_quality/__init__.py +++ b/jdaviz/configs/default/plugins/data_quality/__init__.py @@ -1 +1 @@ -from .dq_utils import * # noqa +from .data_quality import * # noqa diff --git a/jdaviz/configs/default/plugins/data_quality/data_quality.py b/jdaviz/configs/default/plugins/data_quality/data_quality.py new file mode 100644 index 0000000000..25180c89a3 --- /dev/null +++ b/jdaviz/configs/default/plugins/data_quality/data_quality.py @@ -0,0 +1,277 @@ +import os +from traitlets import Any, Dict, Bool, List, Unicode, Float, observe + +import numpy as np +from glue_jupyter.common.toolbar_vuetify import read_icon +from echo import delay_callback +from matplotlib.colors import hex2color + +from jdaviz.core.registries import tray_registry +from jdaviz.core.template_mixin import ( + PluginTemplateMixin, LayerSelect, ViewerSelectMixin +) +from jdaviz.core.user_api import PluginUserApi +from jdaviz.core.tools import ICON_DIR +from jdaviz.configs.default.plugins.data_quality.dq_utils import ( + decode_flags, generate_listed_colormap, dq_flag_map_paths, load_flag_map +) + +__all__ = ['DataQuality'] + +telescope_names = { + "jwst": "JWST", + "roman": "Roman" +} + + +@tray_registry('g-data-quality', label="Data Quality", viewer_requirements="image") +class DataQuality(PluginTemplateMixin, ViewerSelectMixin): + template_file = __file__, "data_quality.vue" + + irrelevant_msg = Unicode("Data Quality plugin is in development.").tag(sync=True) + + # `layer` is the science data layer + science_layer_multiselect = Bool(False).tag(sync=True) + science_layer_items = List().tag(sync=True) + science_layer_selected = Any().tag(sync=True) # Any needed for multiselect + + # `dq_layer` is the data quality layer corresponding to the + # science data in `layer` + dq_layer_multiselect = Bool(False).tag(sync=True) + dq_layer_items = List().tag(sync=True) + dq_layer_selected = Any().tag(sync=True) # Any needed for multiselect + dq_layer_opacity = Float(0.9).tag(sync=True) # Any needed for multiselect + + flag_map_definitions = Dict().tag(sync=True) + flag_map_selected = Any().tag(sync=True) + flag_map_definitions_selected = Dict().tag(sync=True) + flag_map_items = List().tag(sync=True) + decoded_flags = List().tag(sync=True) + flags_filter = List().tag(sync=True) + + icons = Dict().tag(sync=True) + icon_radialtocheck = Unicode(read_icon(os.path.join(ICON_DIR, 'radialtocheck.svg'), 'svg+xml')).tag(sync=True) # noqa + icon_checktoradial = Unicode(read_icon(os.path.join(ICON_DIR, 'checktoradial.svg'), 'svg+xml')).tag(sync=True) # noqa + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.icons = {k: v for k, v in self.app.state.icons.items()} + + self.science_layer = LayerSelect( + self, 'science_layer_items', 'science_layer_selected', + 'viewer_selected', 'science_layer_multiselect', is_root=True + ) + + self.dq_layer = LayerSelect( + self, 'dq_layer_items', 'dq_layer_selected', + 'viewer_selected', 'dq_layer_multiselect', is_root=False, + is_child_of=self.science_layer.selected + ) + + self.load_default_flag_maps() + self.init_decoding() + + @observe('science_layer_selected') + def update_dq_layer(self, *args): + if not hasattr(self, 'dq_layer'): + return + + self.dq_layer.filter_is_child_of = self.science_layer_selected + self.dq_layer._update_layer_items() + + # listen for changes on the image opacity, and update the + # data quality layer opacity on changes to the science layer opacity + plot_options = self.app.get_tray_item_from_name('g-plot-options') + plot_options.observe(self.update_opacity, 'image_opacity_value') + + def load_default_flag_maps(self): + for name in dq_flag_map_paths: + self.flag_map_definitions[name] = load_flag_map(name) + self.flag_map_items = self.flag_map_items + [telescope_names[name]] + + @property + def unique_flags(self): + selected_dq = self.dq_layer.selected_obj + if not len(selected_dq): + return [] + + dq = selected_dq[0].get_image_data() + return np.unique(dq[~np.isnan(dq)]) + + @property + def validate_flag_decode_possible(self): + return ( + self.flag_map_selected is not None and + len(self.dq_layer.selected_obj) > 0 + ) + + @observe('flag_map_selected') + def update_flag_map_definitions_selected(self, event): + selected = self.flag_map_definitions[self.flag_map_selected.lower()] + self.flag_map_definitions_selected = selected + + @observe('dq_layer_selected') + def init_decoding(self, event={}): + if not self.validate_flag_decode_possible: + return + + unique_flags = self.unique_flags + cmap, rgba_colors = generate_listed_colormap(n_flags=len(unique_flags)) + self.decoded_flags = decode_flags( + flag_map=self.flag_map_definitions_selected, + unique_flags=unique_flags, + rgba_colors=rgba_colors + ) + dq_layer = self.get_dq_layer() + dq_layer.composite._allow_bad_alpha = True + + flag_bits = np.array([flag['flag'] for flag in self.decoded_flags]) + + dq_layer.state.stretch = 'lookup' + stretch_object = dq_layer.state.stretch_object + stretch_object.flags = flag_bits + + with delay_callback(dq_layer.state, 'alpha', 'cmap', 'v_min', 'v_max'): + if len(flag_bits): + dq_layer.state.v_min = min(flag_bits) + dq_layer.state.v_max = max(flag_bits) + + dq_layer.state.alpha = self.dq_layer_opacity + dq_layer.state.cmap = cmap + + def get_dq_layer(self): + if self.dq_layer_selected == '': + return + + viewer = self.viewer.selected_obj + [dq_layer] = [ + layer for layer in viewer.layers if + layer.layer.label == self.dq_layer_selected + ] + return dq_layer + + def get_science_layer(self): + viewer = self.viewer.selected_obj + [science_layer] = [ + layer for layer in viewer.layers if + layer.layer.label == self.science_layer_selected + ] + return science_layer + + @observe('dq_layer_opacity') + def update_opacity(self, event={}): + science_layer = self.get_science_layer() + dq_layer = self.get_dq_layer() + + if dq_layer is not None: + # DQ opacity is a fraction of the science layer's opacity: + dq_layer.state.alpha = self.dq_layer_opacity * science_layer.state.alpha + + @observe('decoded_flags', 'flags_filter') + def _update_cmap(self, event={}): + dq_layer = self.get_dq_layer() + flag_bits = np.array([flag['flag'] for flag in self.decoded_flags]) + rgb_colors = [hex2color(flag['color']) for flag in self.decoded_flags] + + hidden_flags = np.array([ + flag['flag'] for flag in self.decoded_flags + + # hide the flag if the visibility toggle is False: + if not flag['show'] or + + # hide the flag if `flags_filter` has entries but not this one: + ( + len(self.flags_filter) and + not np.isin( + list(map(int, flag['decomposed'].keys())), + list(self.flags_filter) + ).any() + ) + ]) + + with delay_callback(dq_layer.state, 'v_min', 'v_max', 'alpha', 'stretch', 'cmap'): + # set correct stretch and limits: + # dq_layer.state.stretch = 'lookup' + stretch_object = dq_layer.state.stretch_object + stretch_object.flags = flag_bits + stretch_object.dq_array = dq_layer.get_image_data() + stretch_object.hidden_flags = hidden_flags + + # update the colors of the listed colormap without + # reassigning the layer.state.cmap object + cmap = dq_layer.state.cmap + cmap.colors = rgb_colors + cmap._init() + + # trigger updates to cmap in viewer: + dq_layer.update() + + if len(flag_bits): + dq_layer.state.v_min = min(flag_bits) + dq_layer.state.v_max = max(flag_bits) + + dq_layer.state.alpha = self.dq_layer_opacity + + def update_visibility(self, index): + self.decoded_flags[index]['show'] = not self.decoded_flags[index]['show'] + self.vue_update_cmap() + + def vue_update_cmap(self): + self.send_state('decoded_flags') + self._update_cmap() + + def vue_update_visibility(self, index): + self.update_visibility(index) + + def update_color(self, index, color): + self.decoded_flags[index]['color'] = color + self.vue_update_cmap() + + def vue_update_color(self, args): + index, color = args + self.update_color(index, color) + + @observe('science_layer_selected') + def mission_or_instrument_from_meta(self, event): + if not hasattr(self, 'science_layer'): + return + + layer = self.science_layer.selected_obj + if not len(layer): + return + + # this is defined for JWST and ROMAN, should be upper case: + telescope = layer[0].layer.meta.get('telescope', None) + + if telescope is not None: + self.flag_map_selected = telescope_names[telescope.lower()] + + def vue_hide_all_flags(self, event): + for flag in self.decoded_flags: + flag['show'] = False + + self.vue_update_cmap() + + def vue_clear_flags_filter(self, event): + self.flags_filter = [] + self.vue_update_cmap() + + def vue_show_all_flags(self, event): + for flag in self.decoded_flags: + flag['show'] = True + + self.flags_filter = [] + self.vue_update_cmap() + + @property + def user_api(self): + return PluginUserApi( + self, + expose=( + 'science_layer', 'dq_layer', + 'decoded_flags', 'flags_filter', + 'viewer', 'dq_layer_opacity', + 'flag_map_definitions_selected', + ) + ) diff --git a/jdaviz/configs/default/plugins/data_quality/data_quality.vue b/jdaviz/configs/default/plugins/data_quality/data_quality.vue new file mode 100644 index 0000000000..4a6b710b40 --- /dev/null +++ b/jdaviz/configs/default/plugins/data_quality/data_quality.vue @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + Data quality relative opacity + + + + + Quality Flags + + + + + mdi-eye + Show All + + + + + + + mdi-eye-off + Hide All + + + + + + + + + + + + + + + + {{item + ': ' + flag_map_definitions_selected[item].name}} + + + + + + + + Clear Filter + + + + + + + Color + + + Flag (Decomposed) + + + + + + + + + + + + + + + + + + + + > + + + + + + {{item.flag}} ({{Object.keys(item.decomposed).join(', ')}}) + + + + + + + + {{item.show ? "mdi-eye" : "mdi-eye-off"}} + + + + + {{item.name}} ({{key}}): {{item.description}} + + + + + + + + + + + + + + + diff --git a/jdaviz/configs/default/plugins/data_quality/dq_utils.py b/jdaviz/configs/default/plugins/data_quality/dq_utils.py index e1ea1619a0..f9bacac5b2 100644 --- a/jdaviz/configs/default/plugins/data_quality/dq_utils.py +++ b/jdaviz/configs/default/plugins/data_quality/dq_utils.py @@ -1,13 +1,96 @@ from importlib import resources from pathlib import Path + +import numpy as np +from matplotlib.colors import ListedColormap, rgb2hex +from glue.config import stretches from astropy.table import Table +# paths to CSV files with DQ flag mappings: dq_flag_map_paths = { 'jwst': Path('data', 'data_quality', 'jwst.csv'), 'roman': Path('data', 'data_quality', 'roman.csv'), } +class LookupStretch: + """ + Stretch class specific to DQ arrays. + + Attributes + ---------- + flags : array-like + DQ flags. + """ + + def __init__(self, flags=None, hidden_flags=None): + # Default x, y values(0-1) range chosen for a typical initial spline shape. + # Can be modified if required. + if flags is None: + flags = np.linspace(0, 1, 5) + if hidden_flags is None: + hidden_flags = [] + + self.flags = np.asarray(flags) + self.hidden_flags = np.asarray(hidden_flags).astype(int) + + @property + def flag_range(self): + return np.max(self.flags) - np.min(self.flags) + + @property + def scaled_flags(self): + # renormalize the flags on range (0, 1): + return (self.flags - np.min(self.flags)) / self.flag_range + + def dq_array_to_flag_index(self, values): + # Find the index of the closest entry in `scaled_flags` + # for each of `values` using array broadcasting: + return np.argmin( + np.abs( + np.nan_to_num(values, nan=-10).flatten()[None, :] - + self.scaled_flags[:, None] + ), axis=0 + ).reshape(values.shape) + + def __call__(self, values, out=None, clip=False): + # For our uses, we can ignore `out` and `clip`, but those would need + # to be implemented before contributing this class upstream. + + # find closest index in `self.flags` for each value in `values`: + if hasattr(values, 'squeeze'): + values = values.squeeze() + + # `values` will have already been passed through + # astropy.visualization.ManualInterval and normalized on (0, 1) + # before they arrive here. First, remove that interval and get + # back the integer values: + values_integer = np.round(values * self.flag_range + np.min(self.flags)) + + # normalize by the number of flags, onto interval (0, 1): + renormed = self.dq_array_to_flag_index(values) / len(self.flags) + + if len(self.hidden_flags): + # mask that is True for `values` in the hidden flags list: + value_is_hidden = np.isin( + np.nan_to_num(values_integer, nan=-10).astype(int), + self.hidden_flags.astype(int) + ) + else: + value_is_hidden = False + + # preserve NaNs in values, and make hidden flags NaNs: + return np.where( + np.isnan(values) | value_is_hidden, + np.nan, + renormed + ) + + +if "lookup" not in stretches: + stretches.add("lookup", LookupStretch, display="DQ") + + def load_flag_map(mission_or_instrument=None, path=None): """ Load a flag map from disk. @@ -37,11 +120,15 @@ def load_flag_map(mission_or_instrument=None, path=None): flag_map_path = path with resources.as_file(resources.files('jdaviz').joinpath(flag_map_path)) as path: - flag_table = Table.read(path, format='ascii.csv') + # by default astropy's Table.read returns "masked" for empty table data, so + # here we supply a fill value (empty string) + fill_values = [''] + + flag_table = Table.read(path, format='ascii.csv', fill_values=fill_values) flag_mapping = {} for flag, name, desc in flag_table.iterrows(): - flag_mapping[flag] = dict(name=name, description=desc) + flag_mapping[int(flag)] = dict(name=name, description=desc) return flag_mapping @@ -71,3 +158,97 @@ def write_flag_map(flag_mapping, csv_path, **kwargs): table.add_row(row) table.write(csv_path, format='ascii.csv', **kwargs) + + +def generate_listed_colormap(n_flags=None, rgba_colors=None, seed=3): + """ + Generate a list of random "light" colors of length ``n_flags``. + + Parameters + ---------- + n_flags : int + Number of colors in the listed colormap, should match the + number of unique DQ flags (before they're decomposed). + rgba_colors : list of tuples + List of RGBA tuples for each color in the colormap. + seed : int + Seed for the random number generator used to + draw random colors. + + Returns + ------- + cmap : `~matplotlib.pyplot.colors.ListedColormap` + Colormap constructed with ``n_flags`` colors. + rgba_colors : list of tuples + Random light colors of length ``n_flags``. + """ + rng = np.random.default_rng(seed) + default_alpha = 1 + + if rgba_colors is None: + # Generate random colors that are generally "light", i.e. with + # RGB values in the upper half of the interval (0, 1): + rgba_colors = [ + tuple(np.insert(rng.uniform(size=2), rng.integers(0, 3), 1).tolist() + [default_alpha]) + for _ in range(n_flags) + ] + + cmap = ListedColormap(rgba_colors) + + # setting `bad` alpha=0 will make NaNs transparent: + cmap.set_bad(color='k', alpha=0) + return cmap, rgba_colors + + +def decompose_bit(bit): + """ + For an integer ``bit``, return a list of the powers of + two that sum up to ``bit``. + + Parameters + ---------- + bit : int + Sum of powers of two. + + Returns + ------- + powers : list of integers + Powers of two which sum to ``bit``. + """ + bit = int(bit) + powers = [] + i = 1 + while i <= bit: + if i & bit: + powers.append(int(np.log2(i))) + i <<= 1 + return sorted(powers) + + +def decode_flags(flag_map, unique_flags, rgba_colors): + """ + For a list of unique bits in ``unique_flags``, return a list of + dictionaries of the decomposed bits with their names, definitions, and + colors defined in ``rgba_colors``. + + Parameters + ---------- + flag_map : dict + Flag mapping, such as the ones produced by ``load_flag_map``. + unique_flags : list or array + Sequence of unique flags which occur in a data quality array. + rgba_colors : list of tuples + RGBA color tuples, one per unique flag. + """ + decoded_flags = [] + + for i, (bit, color) in enumerate(zip(unique_flags, rgba_colors)): + decoded_bits = decompose_bit(bit) + decoded_flags.append({ + 'flag': int(bit), + 'decomposed': {bit: flag_map[bit] for bit in decoded_bits}, + 'color': rgb2hex(color), + 'show': True, + }) + + return decoded_flags diff --git a/jdaviz/configs/default/plugins/data_quality/tests/test_data_quality.py b/jdaviz/configs/default/plugins/data_quality/tests/test_data_quality.py index c06573b73b..10971f63e1 100644 --- a/jdaviz/configs/default/plugins/data_quality/tests/test_data_quality.py +++ b/jdaviz/configs/default/plugins/data_quality/tests/test_data_quality.py @@ -1,6 +1,9 @@ +import warnings +from pathlib import Path import pytest -import numpy as np +import numpy as np +from astroquery.mast import Observations from stdatamodels.jwst.datamodels.dqflags import pixel as pixel_jwst from jdaviz.configs.imviz.plugins.parsers import HAS_ROMAN_DATAMODELS @@ -70,3 +73,100 @@ def test_roman_against_rdm(): for flag in flag_map_expected: assert flag_map_loaded[flag]['name'] == flag_map_expected[flag]['name'] + + +@pytest.mark.remote_data +def test_data_quality_plugin(imviz_helper, tmp_path): + # while the DQ plugin is in development, we are making it + # irrelevant by default. This fixture allows us to use DQ + # plugin in the tests until it's out of development. + dq_plugin = imviz_helper.app.get_tray_item_from_name('g-data-quality') + dq_plugin.irrelevant_msg = "" + + uri = "mast:JWST/product/jw01895001004_07101_00001_nrca3_cal.fits" + download_path = str(tmp_path / Path(uri).name) + Observations.download_file(uri, local_path=download_path) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + imviz_helper.load_data(download_path, ext=('SCI', 'DQ')) + + assert len(imviz_helper.app.data_collection) == 2 + + # this assumption is made in the DQ plugin (for now) + assert imviz_helper.app.data_collection[-1].label.endswith('[DQ]') + + dq_plugin = imviz_helper.plugins['Data Quality']._obj + + # sci+dq layers are correctly identified + expected_science_data, expected_dq_data = imviz_helper.app.data_collection + assert dq_plugin.science_layer_selected == expected_science_data.label + assert dq_plugin.dq_layer_selected == expected_dq_data.label + + # JWST data product identified as such: + assert dq_plugin.flag_map_selected == 'JWST' + + # check flag 0 is a bad pixel in the JWST flag map: + flag_map_selected = dq_plugin.flag_map_definitions_selected + assert flag_map_selected[0]['name'] == 'DO_NOT_USE' + assert flag_map_selected[0]['description'] == 'Bad pixel. Do not use.' + + # check default dq opacity is a fraction of sci data: + sci_alpha = imviz_helper.default_viewer._obj.layers[0].state.alpha + dq_alpha = imviz_helper.default_viewer._obj.layers[1].state.alpha + assert dq_alpha == sci_alpha * dq_plugin.dq_layer_opacity + + plot_opts = imviz_helper.plugins['Plot Options']._obj + + # only the sci data appears in Plot Options: + assert len(plot_opts.layer_items) == 1 + + # check changes to sci opacity affect dq opacity + new_sci_opacity = 0.5 + plot_opts.image_opacity_value = new_sci_opacity + dq_alpha = imviz_helper.default_viewer._obj.layers[1].state.alpha + assert dq_alpha == new_sci_opacity * dq_plugin.dq_layer_opacity + + # check that mouseover shows dq values on bad pixels (flag == 0): + # check that mouseover shows dq values on bad pixels (flag == 0): + viewer = imviz_helper.default_viewer._obj + label_mouseover = imviz_helper.app.session.application._tools['g-coords-info'] + label_mouseover._viewer_mouse_event(viewer, + {'event': 'mousemove', 'domain': {'x': 1366, 'y': 708}}) + + # DQ features are labeled in the first line: + label_mouseover_text = label_mouseover.as_text()[0] + + # bad pixels with flag == 0 have flux == NaN + expected_flux_label = '+nan MJy/sr' + assert expected_flux_label in label_mouseover_text + + # check that the decomposed DQ flag is at the end of the flux label's line: + flux_label_idx = label_mouseover_text.index(expected_flux_label) + assert label_mouseover_text[flux_label_idx + len(expected_flux_label) + 1:] == '(DQ: 1)' + + # check that a flagged pixel that is not marked with the bit 0 has a flux in mouseover label: + label_mouseover._viewer_mouse_event(viewer, + {'event': 'mousemove', 'domain': {'x': 1371, 'y': 715}}) + label_mouseover_text = label_mouseover.as_text()[0] + assert label_mouseover_text.split('+')[1] == '2.94744e-01 MJy/sr (DQ: 4)' + + # check that a pixel without a DQ flag has no DQ mouseover label: + label_mouseover._viewer_mouse_event(viewer, + {'event': 'mousemove', 'domain': {'x': 1361, 'y': 684}}) + label_mouseover_text = label_mouseover.as_text()[0] + assert 'DQ' not in label_mouseover_text + + # set a bit filter, then clear it: + assert len(dq_plugin.flags_filter) == 0 + dq_plugin.flags_filter = [0, 1] + dq_plugin.vue_clear_flags_filter({}) + assert len(dq_plugin.flags_filter) == 0 + + # hide all: + dq_plugin.vue_hide_all_flags({}) + assert not any([flag['show'] for flag in dq_plugin.decoded_flags]) + + # now show all: + dq_plugin.vue_show_all_flags({}) + assert all([flag['show'] for flag in dq_plugin.decoded_flags]) diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index 415d844c91..b1ee5a0535 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -11,7 +11,7 @@ from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin from jdaviz.core.registries import viewer_registry from jdaviz.core.user_api import ViewerUserApi -from jdaviz.utils import ColorCycler, get_subset_type, _wcs_only_label +from jdaviz.utils import ColorCycler, get_subset_type, _wcs_only_label, layer_is_not_dq __all__ = ['JdavizViewerMixin'] @@ -314,8 +314,9 @@ def active_image_layer(self): """Active image layer in the viewer, if available.""" # Find visible layers visible_layers = [layer for layer in self.state.layers - if (layer.visible and layer_is_image_data(layer.layer))] - + if (layer.visible and + layer_is_image_data(layer.layer) and + layer_is_not_dq(layer.layer))] if len(visible_layers) == 0: return None diff --git a/jdaviz/configs/imviz/helper.py b/jdaviz/configs/imviz/helper.py index 6375818558..11e04c6161 100644 --- a/jdaviz/configs/imviz/helper.py +++ b/jdaviz/configs/imviz/helper.py @@ -416,9 +416,17 @@ def get_top_layer_index(viewer): This is because when blinked, first layer might not be top visible layer. """ + # exclude children of layer associations + associations = viewer.jdaviz_app._data_associations + visible_image_layers = [ i for i, lyr in enumerate(viewer.layers) - if lyr.visible and layer_is_image_data(lyr.layer) + if ( + lyr.visible and + layer_is_image_data(lyr.layer) and + # check that this layer is a root, without parents: + associations[lyr.layer.label]['parent'] is None + ) ] if len(visible_image_layers): diff --git a/jdaviz/configs/imviz/imviz.yaml b/jdaviz/configs/imviz/imviz.yaml index 4e7de049bb..2934cebabf 100644 --- a/jdaviz/configs/imviz/imviz.yaml +++ b/jdaviz/configs/imviz/imviz.yaml @@ -24,6 +24,7 @@ tray: - g-plot-options - g-subset-plugin - g-markers + - g-data-quality - imviz-compass - imviz-line-profile-xy - imviz-aper-phot-simple diff --git a/jdaviz/configs/imviz/plugins/coords_info/coords_info.py b/jdaviz/configs/imviz/plugins/coords_info/coords_info.py index d2bcdd45d3..1392d052b2 100644 --- a/jdaviz/configs/imviz/plugins/coords_info/coords_info.py +++ b/jdaviz/configs/imviz/plugins/coords_info/coords_info.py @@ -66,7 +66,8 @@ def __init__(self, *args, **kwargs): viewer_refs.append(viewer.reference_id) self.dataset._manual_options = ['auto', 'none'] - self.dataset.filters = ['layer_in_viewers', 'is_not_wcs_only'] + + self.dataset.filters = ['layer_in_viewers', 'is_not_wcs_only', 'layer_is_not_dq'] if self.app.config == 'imviz': # filter out scatter-plot entries (from add_markers API, for example) self.dataset.add_filter('is_image') @@ -274,6 +275,15 @@ def _image_viewer_update(self, viewer, x, y): else: image = None + # If there is one, get the associated DQ layer for the active layer: + associated_dq_layer = None + available_plugins = [tray_item['name'] for tray_item in self.app.state.tray_items] + if 'g-data-quality' in available_plugins: + assoc_children = self.app._get_assoc_data_children(active_layer.layer.label) + if assoc_children: + data_quality_plugin = self.app.get_tray_item_from_name('g-data-quality') + associated_dq_layer = data_quality_plugin.get_dq_layer() + unreliable_pixel, unreliable_world = False, False self._dict['axes_x'] = x @@ -429,16 +439,30 @@ def _image_viewer_update(self, viewer, x, y): if (-0.5 < x < image.shape[ix_shape] - 0.5 and -0.5 < y < image.shape[iy_shape] - 0.5 and hasattr(active_layer, 'attribute')): + attribute = active_layer.attribute + if isinstance(viewer, (ImvizImageView, MosvizImageView, MosvizProfile2DView)): value = image.get_data(attribute)[int(round(y)), int(round(x))] + if associated_dq_layer is not None: + dq_attribute = associated_dq_layer.state.attribute + dq_data = associated_dq_layer.layer.get_data(dq_attribute) + dq_value = dq_data[int(round(y)), int(round(x))] unit = image.get_component(attribute).units elif isinstance(viewer, CubevizImageView): arr = image.get_component(attribute).data unit = image.get_component(attribute).units value = self._get_cube_value(image, arr, x, y, viewer) self.row1b_title = 'Value' - self.row1b_text = f'{value:+10.5e} {unit}' + + if associated_dq_layer is not None: + if np.isnan(dq_value): + dq_text = '' + else: + dq_text = f' (DQ: {int(dq_value):d})' + else: + dq_text = '' + self.row1b_text = f'{value:+10.5e} {unit}{dq_text}' self._dict['value'] = float(value) self._dict['value:unit'] = unit self._dict['value:unreliable'] = unreliable_pixel diff --git a/jdaviz/configs/imviz/plugins/parsers.py b/jdaviz/configs/imviz/plugins/parsers.py index b6ea47938d..96e4eae94a 100644 --- a/jdaviz/configs/imviz/plugins/parsers.py +++ b/jdaviz/configs/imviz/plugins/parsers.py @@ -29,6 +29,15 @@ " to the file name to load all of them.") +def prep_data_layer_as_dq(data, component_id='DQ'): + # nans are used to mark "good" flags in the DQ colormap, so + # convert DQ array to float to support nans: + cid = data.get_component(component_id) + data_arr = np.float32(cid.data) + data_arr[data_arr == 0] = np.nan + data.update_components({cid: data_arr}) + + @data_parser_registry("imviz-data-parser") def parse_data(app, file_obj, ext=None, data_label=None, parent=None): """Parse a data file into Imviz. @@ -92,6 +101,19 @@ def parse_data(app, file_obj, ext=None, data_label=None, parent=None): else: # Assume FITS with fits.open(file_obj) as pf: + available_extensions = [hdu.name for hdu in pf] + + # if FITS file contains SCI and DQ extensions, assume the + # parent for the DQ is SCI: + if ( + 'SCI' in available_extensions and + ext == 'DQ' and parent is None + ): + loaded_data_labels = [data.label for data in app.data_collection] + latest_sci_extension = [label for label in loaded_data_labels + if label.endswith('[DATA]')][-1] + parent = latest_sci_extension + _parse_image(app, pf, data_label, ext=ext, parent=parent) else: _parse_image(app, file_obj, data_label, ext=ext, parent=parent) @@ -101,12 +123,11 @@ def get_image_data_iterator(app, file_obj, data_label, ext=None): """This function is for internal use, so other viz can also extract image data like Imviz does. """ - if isinstance(file_obj, fits.HDUList): if 'ASDF' in file_obj: # JWST ASDF-in-FITS - # Load all extensions - if ext == '*': - data_iter = _jwst_all_to_glue_data(file_obj, data_label) + # Load multiple extensions + if ext == '*' or isinstance(ext, (tuple, list)): + data_iter = _jwst_all_to_glue_data(file_obj, data_label, load_extensions=ext) # Load only specified extension else: @@ -116,8 +137,8 @@ def get_image_data_iterator(app, file_obj, data_label, ext=None): # issue info message. _info_nextensions(app, file_obj) - elif ext == '*': # Load all extensions - data_iter = _hdus_to_glue_data(file_obj, data_label) + elif ext == '*' or isinstance(ext, (tuple, list)): # Load multiple extensions + data_iter = _hdus_to_glue_data(file_obj, data_label, ext=ext) elif ext is not None: # Load just the EXT user wants hdu = file_obj[ext] @@ -175,7 +196,15 @@ def _parse_image(app, file_obj, data_label, ext=None, parent=None): data_label = app.return_data_label(file_obj, ext, alt_name="image_data") data_iter = get_image_data_iterator(app, file_obj, data_label, ext=ext) + # Save the SCI extension to this list: + sci_ext = None + for data, data_label in data_iter: + + # if the science extension hasn't been identified yet, do so here: + if sci_ext is None and data_label.endswith(('[DATA]', '[SCI]')): + sci_ext = data_label + if isinstance(data.coords, GWCS) and (data.coords.bounding_box is not None): # keep a copy of the original bounding box so we can detect # when extrapolating beyond, but then remove the bounding box @@ -188,14 +217,17 @@ def _parse_image(app, file_obj, data_label, ext=None, parent=None): data_label = app.return_data_label(data_label, alt_name="image_data") # TODO: generalize/centralize this for use in other configs too - if parent is not None and ext == 'DQ': - # nans are used to mark "good" flags in the DQ colormap, so - # convert DQ array to float to support nans: - cid = data.get_component("DQ") - data_arr = np.float32(cid.data) - data_arr[data_arr == 0] = np.nan - data.update_components({cid: data_arr}) - app.add_data(data, data_label, parent=parent) + if data_label.endswith('[DQ]'): + prep_data_layer_as_dq(data) + + if parent is not None: + parent_data_label = parent + elif data_label.endswith('[DQ]'): + parent_data_label = sci_ext + else: + parent_data_label = None + + app.add_data(data, data_label, parent=parent_data_label) # Do not link image data here. We do it at the end in Imviz.load_data() @@ -248,9 +280,12 @@ def _validate_bunit(bunit, raise_error=True): # ---- Functions that handle input from JWST FITS files ----- -def _jwst_all_to_glue_data(file_obj, data_label): +def _jwst_all_to_glue_data(file_obj, data_label, load_extensions='*'): for hdu in file_obj: - if _validate_fits_image2d(hdu, raise_error=False): + if ( + _validate_fits_image2d(hdu, raise_error=False) and + (load_extensions == '*' or hdu.name in load_extensions) + ): ext = hdu.name.lower() if ext == 'sci': @@ -329,9 +364,11 @@ def _jwst2data(file_obj, ext, data_label): def _roman_2d_to_glue_data(file_obj, data_label, ext=None): - if ext == '*' or ext is None: + if ext == '*': # NOTE: Update as needed. Should cover all the image extensions available. ext_list = ('data', 'dq', 'err', 'var_poisson', 'var_rnoise') + elif ext is None: + ext_list = ('data', ) elif isinstance(ext, (list, tuple)): ext_list = ext else: @@ -352,13 +389,18 @@ def _roman_2d_to_glue_data(file_obj, data_label, ext=None): data.add_component(component=component, label=comp_label) data.meta.update(standardize_metadata(dict(meta))) + if comp_label == 'dq': + prep_data_layer_as_dq(data, component_id=comp_label) + yield data, new_data_label def _roman_asdf_2d_to_glue_data(file_obj, data_label, ext=None): - if ext == '*' or ext is None: + if ext == '*': # NOTE: Update as needed. Should cover all the image extensions available. ext_list = ('data', 'dq', 'err', 'var_poisson', 'var_rnoise') + elif ext is None: + ext_list = ('data', ) elif isinstance(ext, (list, tuple)): ext_list = ext else: @@ -380,6 +422,8 @@ def _roman_asdf_2d_to_glue_data(file_obj, data_label, ext=None): component = Component(np.array(ext_values), units=bunit) data.add_component(component=component, label=comp_label) data.meta.update(standardize_metadata(dict(meta))) + if comp_label == 'DQ': + prep_data_layer_as_dq(data, component_id=comp_label) yield data, new_data_label @@ -391,11 +435,12 @@ def _hdu_to_glue_data(hdu, data_label, hdulist=None): yield data, data_label -def _hdus_to_glue_data(file_obj, data_label): +def _hdus_to_glue_data(file_obj, data_label, ext=None): for hdu in file_obj: - if _validate_fits_image2d(hdu, raise_error=False): - data, new_data_label = _hdu2data(hdu, data_label, file_obj) - yield data, new_data_label + if ext is None or ext == '*' or hdu.name in ext: + if _validate_fits_image2d(hdu, raise_error=False): + data, new_data_label = _hdu2data(hdu, data_label, file_obj) + yield data, new_data_label def _hdu2data(hdu, data_label, hdulist, include_wcs=True): diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index 8647b4c6b1..30eccf7b4a 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -56,7 +56,10 @@ from jdaviz.core.tools import ICON_DIR from jdaviz.core.user_api import UserApiWrapper, PluginUserApi from jdaviz.style_registry import PopoutStyleWrapper -from jdaviz.utils import get_subset_type, is_wcs_only, is_not_wcs_only, _wcs_only_label +from jdaviz.utils import ( + get_subset_type, is_wcs_only, is_not_wcs_only, + _wcs_only_label, layer_is_not_dq as layer_is_not_dq_global +) __all__ = ['show_widget', 'TemplateMixin', 'PluginTemplateMixin', @@ -1341,8 +1344,9 @@ def __init__(self, plugin, items, selected, viewer, multiselect=None, default_text=None, manual_options=[], default_mode='first', - filters=['not_child_layer'], - only_wcs_layers=False): + only_wcs_layers=False, + is_root=True, + is_child_of=None): """ Parameters ---------- @@ -1395,6 +1399,25 @@ def __init__(self, plugin, items, selected, viewer, self._update_layer_items() self.update_wcs_only_filter(only_wcs_layers) + self.filter_is_root = is_root + self.filter_is_child_of = is_child_of + + if self.filter_is_root: + # ignore layers that are children in associations: + def filter_is_root(data): + return self.app._get_assoc_data_parent(data.label) is None + + self.add_filter(filter_is_root) + + elif not self.filter_is_root and self.filter_is_child_of is not None: + # only offer layers that are children of the correct parent: + def has_correct_parent(data): + if self.filter_is_child_of == '': + return False + return self.app._get_assoc_data_parent(data.label) == self.filter_is_child_of + + self.add_filter(has_correct_parent) + def _get_viewer(self, viewer): # newer will likely be the viewer name in most cases, but viewer id in the case # of additional viewers in imviz. @@ -3334,6 +3357,8 @@ def not_child_layer(data): # ignore layers that are children in associations: return self.app._get_assoc_data_parent(data.label) is None + layer_is_not_dq = layer_is_not_dq_global + return super()._is_valid_item(data, locals()) @observe('filters') diff --git a/jdaviz/utils.py b/jdaviz/utils.py index f6f5a437c3..bcc141b574 100644 --- a/jdaviz/utils.py +++ b/jdaviz/utils.py @@ -249,6 +249,10 @@ def is_not_wcs_only(layer): return not is_wcs_only(layer) +def layer_is_not_dq(data): + return not data.label.endswith('[DQ]') + + def standardize_metadata(metadata): """Standardize given metadata so it can be viewed in Metadata Viewer plugin. The input can be plain diff --git a/notebooks/concepts/imviz_dq_concept.ipynb b/notebooks/concepts/imviz_dq_concept.ipynb new file mode 100644 index 0000000000..f1b489f292 --- /dev/null +++ b/notebooks/concepts/imviz_dq_concept.ipynb @@ -0,0 +1,132 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "67e77fe2-df29-4224-9989-685aa9755e0a", + "metadata": {}, + "source": [ + "# DQ plugin concept" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82dca7aa-fc11-4c6e-987b-2ad1df1c9dc7", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "from astroquery.mast import Observations\n", + "import matplotlib.pyplot as plt\n", + "from jdaviz import Imviz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd42275f-ee86-47c0-bf7e-c87ec82c72c3", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "from astroquery.mast import Observations\n", + "from jdaviz import Imviz\n", + "\n", + "imviz = Imviz()\n", + "\n", + "data_dir = \"../example_files\"\n", + "\n", + "uris = [\n", + " 'mast:JWST/product/jw01895001004_07101_00001_nrca3_cal.fits',\n", + "\n", + " # # Bonus: try with a second science data+DQ pair:\n", + " # 'mast:JWST/product/jw01895001004_04101_00001_nrca3_cal.fits'\n", + "]\n", + "\n", + "for uri in uris:\n", + " fn = uri.split('/')[-1]\n", + " path = f'{data_dir}/{fn}'\n", + " result = Observations.download_file(uri, local_path=path)\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore')\n", + " imviz.load_data(path, ext=('SCI', 'DQ'))\n", + "\n", + " ## Must also support:\n", + " # imviz.load_data(path, ext=\"*\")\n", + "\n", + "imviz.load_data(path, data_label='Roman L2', ext=('data', 'dq'))\n", + "\n", + "imviz.show(height=900)\n", + "\n", + "dq_plugin = imviz.plugins['Data Quality']._obj\n", + "\n", + "dq_plugin.open_in_tray()" + ] + }, + { + "cell_type": "markdown", + "id": "2cb12a5c-f6b8-4496-87c0-e61416fb2be5", + "metadata": {}, + "source": [ + "Download Roman example data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf968192-2bd5-4b6c-a905-5e4cc74d5b9a", + "metadata": {}, + "outputs": [], + "source": [ + "from astropy.utils.data import download_file\n", + "\n", + "# With the following optional dependencies, load a Roman L2 file:\n", + "# roman_datamodels==0.18 rad==0.18\n", + "roman_l2_url = 'https://stsci.box.com/shared/static/ktpt4li627kq4mipi3er5yd4qw6hq7ll.asdf'\n", + "path = download_file(roman_l2_url, cache=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e03209fa-b14e-444a-a631-0e5b64809cbb", + "metadata": {}, + "source": [ + "Load into Imviz:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b52bac5c-eceb-4f0c-b30a-44b711681e7c", + "metadata": {}, + "outputs": [], + "source": [ + "imviz.load_data(path, data_label='Roman L2', ext=('data', 'dq'))\n", + "\n", + "# Select the science data layer in the DQ plugin:\n", + "dq_plugin.science_layer.selected = 'Roman L2[DATA]'" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 17591ae389..5c10ad09d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "traitlets>=5.0.5", "bqplot>=0.12.37", "bqplot-image-gl>=1.4.11", - "glue-core>=1.17.1,!=1.19.0", + "glue-core>=1.18.0,!=1.19.0", "glue-jupyter>=0.20", "echo>=0.5.0", "ipykernel>=6.19.4",