diff --git a/draco/analysis/interpolate.py b/draco/analysis/interpolate.py new file mode 100644 index 00000000..de4674c1 --- /dev/null +++ b/draco/analysis/interpolate.py @@ -0,0 +1,282 @@ +"""Tasks to do data interpolation/inpainting.""" + +import numpy as np +from caput import config +from cora.util import units + +from ..core import io, task +from ..util import dpss + + +class DPSSInpaint(task.SingleTask): + """Fill data gaps using DPSS inpainting. + + Discrete prolate spheroidal sequence (DPSS) inpainting involves + projecting a partially-masked data series onto a basis which + maximally concentrates spectral power within a defined window. + This basis, called the nth order discrete prolate spheroidal + sequence or Slepian sequence consists of the large eigenvectors + of a covariance matrix defined as a sum of `sinc` functions, + which represent top-hats in the spectral inverse of the data. + + Attributes + ---------- + axis : str + Name of the axis over which to inpaint. Only one-dimensional + inpainting is currently supported. Must be either "freq" or + "ra". Default is `freq`. + iter_axes : list[str] + List of independent axes over which to iterate. This can + include axes not in the dataset, but at least one of these + axes should be present. Default is ["stack", "el"]. + centres : list + List of top-hat window centres. If all windows are centred + about zero, the covariance matrix will be real, which provides + significant performance improvements. + halfwidths : list + List of window half-widths. Must be the same length as `centres`. + snr_cov : float + Wiener filter inverse signal covariance. Default is 1.0e-3. + copy : bool + If true, copy the container instead of inpainting in-place. + """ + + axis = config.enum(["freq", "ra"], default="freq") + iter_axes = config.Property(proptype=list, default=["stack", "el"]) + centres = config.Property(proptype=list) + halfwidths = config.Property(proptype=list) + snr_cov = config.Property(proptype=float, default=1.0e-3) + copy = config.Property(proptype=bool, default=True) + + def process(self, data): + """Inpaint visibility data. + + Parameters + ---------- + data : containers.VisContainer + Container with a visibility dataset + + Returns + ------- + data : containers.VisContainer + Input container with masked values filled + """ + try: + # Get the axis samples + samples = getattr(data, self.axis) + except AttributeError as exc: + raise ValueError(f"Could not get axis `{self.axis}`.") from exc + + # Redistribute over an independent axis + data.redistribute(self.iter_axes) + # Set the local selection over the distributed axis + self._set_sel(data) + + vinp, _ = self.inpaint(data.vis, data.weight, samples) + # TODO: need a weight estimate + + # Make the output container + out = data.copy() if self.copy else data + out.redistribute(self.iter_axes) + + out.vis[:].local_array[:] = vinp + + return out + + def inpaint(self, vis, weight, samples): + """Inpaint visibilities using a wiener filter. + + Use a single sequence for the entire dataset. + """ + # Move the iteration and interpolation axes + # to the front and flatten the other axes + vobs, vaxind = _flatten_axes(vis, (*self.iter_axes, self.axis)) + wobs, waxind = _flatten_axes(weight, (*self.iter_axes, self.axis)) + + # Pre-allocate the full output array + vinp = np.zeros_like(vobs) + winp = np.zeros_like(wobs) + + # Construct the covariance matrix and get dpss modes + modes, amap = self._get_basis(samples) + + # Iterate over the variable axis + for ii in range(vobs.shape[0]): + # Get the correct basis for each slice + A = modes[amap[ii]] + # Write to the preallocated output array + vinp[ii] = self.inpaint_single(vobs[ii], wobs[ii], A) + + # Reshape and move the interpolation axis back + vinp = _inv_move_front(vinp, vaxind, vis.local_shape) + winp = _inv_move_front(winp, waxind, weight.local_shape) + + return vinp, winp + + def inpaint_single(self, vobs, wobs, A): + """Inpaint a data slice.""" + # Project visibilities into the dpss basis + vproj = dpss.project(vobs, wobs, A) + # Solve for basis coefficients + b = dpss.solve(vproj, wobs, A, Si=self.snr_cov) + # Get the inpainted visibilities + return dpss.inpaint(A, b, vobs, wobs > 0) + + def _set_sel(self, data): + """Extract selection along local axis.""" + self._local_sel = data.vis[:].local_bounds + + def _get_basis(self, samples): + """Make the DPSS basis. + + Returns a list of bases and a map. + """ + # Construct the covariance matrix and get dpss modes + cov = dpss.make_covariance(samples, self.halfwidths, self.centres) + modes = dpss.get_sequence(cov) + # All iterations map to the same basis + amap = [0] * (self._local_sel.stop - self._local_sel.start) + + return [modes], amap + + +class DPSSInpaintDelay(DPSSInpaint): + """Inpaint with baseline-dependent delay cut. + + Attributes + ---------- + axis : str + Name of axis over which to inpaint. `freq` is the only + accepted argument. + za_cut : float + Sine of the maximum zenith angle included in baseline-dependent delay + filtering. Default is 1 which corresponds to the horizon (ie: filters out all + zenith angles). Setting to zero turns off baseline dependent cut. + extra_cut : float + Increase the delay threshold beyond the baseline dependent term. + telescope_orientation : one of ('NS', 'EW', 'none') + Determines if the baseline-dependent delay cut is based on the north-south + component, the east-west component or the full baseline length. For + cylindrical telescopes oriented in the NS direction (like CHIME) use 'NS'. + The default is 'NS'. + """ + + axis = config.enum(["freq"], default="freq") + za_cut = config.Property(proptype=float, default=1.0) + extra_cut = config.Property(proptype=float, default=0.0) + telescope_orientation = config.enum(["NS", "EW", "none"], default="NS") + + def setup(self, telescope): + """Load a telescope object. + + This is required to establish baseline-dependent + delay cuts. + + Parameters + ---------- + telescope : TransitTelescope + Telescope object with baseline information. + """ + self.telescope = io.get_telescope(telescope) + + def _set_sel(self, data): + """Set the local baselines.""" + prod = data.prodstack + sel = self.telescope.feedmap[(prod["input_a"], prod["input_b"])] + + self._baselines = self.telescope.baselines[sel] + + def _get_basis(self, samples): + """Make the DPSS basis for each unique delay cut. + + Returns a list of bases and a map. + """ + # Calculate delay cuts based on telescope orientation + if self.telescope_orientation == "NS": + blen = abs(self._baselines[:, 1]) + elif self.telescope_orientation == "EW": + blen = abs(self._baselines[:, 0]) + else: + blen = np.linalg.norm(self._baselines, axis=1) + + # Get the delay cut for each baseline. Round delay cuts + # to four decimal places to reduce repeat calculations + delay_cut = self.za_cut * blen / units.c * 1.0e6 + self.extra_cut + delay_cut = np.maximum(delay_cut, self.halfwidths[0]) + delay_cut = np.round(delay_cut, decimals=3) + + # Compute covariances for each unique baseline and + # map to each individual baseline. + delay_cut, amap = np.unique(delay_cut, return_inverse=True) + + modes = [] + + for ii, cut in enumerate(delay_cut): + self.log.debug(f"Making unique covariance {ii}/{len(delay_cut)}.") + cov = dpss.make_covariance(samples, cut, 0.0) + modes.append(dpss.get_sequence(cov)) + + return modes, amap + + +class DPSSInpaintDelayStokesI(DPSSInpaintDelay): + """Inpaint with baseline-dependent delay cut. + + The input container must contain Stokes I visibilities. + """ + + def _set_sel(self, data): + """Set the local baselines.""" + # Baseline lengths extracted from the stack axis + self._baselines = data.stack[data.vis[:].local_bounds] + + +def _flatten_axes(data, axes): + """Move the specified axes to the front of a dataset. + + Not all the axes in `axes` need to be present, but at + least one must exist + """ + dax = list(data.attrs["axis"]) + + axind = [dax.index(axis) for axis in axes if axis in dax] + + if not axind: + raise ValueError( + f"No matching axes. Dataset has axes {dax}, " + f"but axes {axes} were requested." + ) + + return _move_front(data[:].local_array, axind, data.local_shape), axind + + +def _move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray: + """Move specified axes to the front and flatten remaining axes.""" + if np.isscalar(axis): + axis = [axis] + + new_shape = [shape[i] for i in axis] + # Move the N specified axes to the first N positions + inds = list(range(len(axis))) + # Move the specified axes to the front and flatten + # the remaining axes + arr = np.moveaxis(arr, axis, inds) + + return arr.reshape(*new_shape, -1) + + +def _inv_move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray: + """Move axes back to their original position and expand.""" + if np.isscalar(axis): + axis = [axis] + + new_shape = [shape[i] for i in axis] + new_shape += [sh for sh in shape if sh not in new_shape] + inds = list(range(len(axis))) + + # Undo the flattening process + arr = arr.reshape(new_shape) + # Move axes back to their original positions + arr = np.moveaxis(arr, inds, axis) + + return arr.reshape(shape)