-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(interpolate): tasks to do generic and delay-specific dpss interp…
…olation
- Loading branch information
Showing
1 changed file
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
"""Tasks to do data interpolation/inpainting.""" | ||
|
||
import numpy as np | ||
from caput import config | ||
from cora.util import units | ||
|
||
from ..core import io, task | ||
from ..util import dpss | ||
|
||
|
||
class DPSSInpaint(task.SingleTask): | ||
"""Fill data gaps using DPSS inpainting. | ||
Discrete prolate spheroidal sequence (DPSS) inpainting involves | ||
projecting a partially-masked data series onto a basis which | ||
maximally concentrates spectral power within a defined window. | ||
This basis, called the nth order discrete prolate spheroidal | ||
sequence or Slepian sequence consists of the large eigenvectors | ||
of a covariance matrix defined as a sum of `sinc` functions, | ||
which represent top-hats in the spectral inverse of the data. | ||
Attributes | ||
---------- | ||
axis : str | ||
Name of the axis over which to inpaint. Only one-dimensional | ||
inpainting is currently supported. Must be either "freq" or | ||
"ra". Default is `freq`. | ||
iter_axes : list[str] | ||
List of independent axes over which to iterate. This can | ||
include axes not in the dataset, but at least one of these | ||
axes should be present. Default is ["stack", "el"]. | ||
centres : list | ||
List of top-hat window centres. If all windows are centred | ||
about zero, the covariance matrix will be real, which provides | ||
significant performance improvements. | ||
halfwidths : list | ||
List of window half-widths. Must be the same length as `centres`. | ||
snr_cov : float | ||
Wiener filter inverse signal covariance. Default is 1.0e-3. | ||
copy : bool | ||
If true, copy the container instead of inpainting in-place. | ||
""" | ||
|
||
axis = config.enum(["freq", "ra"], default="freq") | ||
iter_axes = config.Property(proptype=list, default=["stack", "el"]) | ||
centres = config.Property(proptype=list) | ||
halfwidths = config.Property(proptype=list) | ||
snr_cov = config.Property(proptype=float, default=1.0e-3) | ||
copy = config.Property(proptype=bool, default=True) | ||
|
||
def process(self, data): | ||
"""Inpaint visibility data. | ||
Parameters | ||
---------- | ||
data : containers.VisContainer | ||
Container with a visibility dataset | ||
Returns | ||
------- | ||
data : containers.VisContainer | ||
Input container with masked values filled | ||
""" | ||
try: | ||
# Get the axis samples | ||
samples = getattr(data, self.axis) | ||
except AttributeError as exc: | ||
raise ValueError(f"Could not get axis `{self.axis}`.") from exc | ||
|
||
# Redistribute over an independent axis | ||
data.redistribute(self.iter_axes) | ||
# Set the local selection over the distributed axis | ||
self._set_sel(data) | ||
|
||
vinp, _ = self.inpaint(data.vis, data.weight, samples) | ||
# TODO: need a weight estimate | ||
|
||
# Make the output container | ||
out = data.copy() if self.copy else data | ||
out.redistribute(self.iter_axes) | ||
|
||
out.vis[:].local_array[:] = vinp | ||
|
||
return out | ||
|
||
def inpaint(self, vis, weight, samples): | ||
"""Inpaint visibilities using a wiener filter. | ||
Use a single sequence for the entire dataset. | ||
""" | ||
# Move the iteration and interpolation axes | ||
# to the front and flatten the other axes | ||
vobs, vaxind = _flatten_axes(vis, (*self.iter_axes, self.axis)) | ||
wobs, waxind = _flatten_axes(weight, (*self.iter_axes, self.axis)) | ||
|
||
# Pre-allocate the full output array | ||
vinp = np.zeros_like(vobs) | ||
winp = np.zeros_like(wobs) | ||
|
||
# Construct the covariance matrix and get dpss modes | ||
modes, amap = self._get_basis(samples) | ||
|
||
# Iterate over the variable axis | ||
for ii in range(vobs.shape[0]): | ||
# Get the correct basis for each slice | ||
A = modes[amap[ii]] | ||
# Write to the preallocated output array | ||
vinp[ii] = self.inpaint_single(vobs[ii], wobs[ii], A) | ||
|
||
# Reshape and move the interpolation axis back | ||
vinp = _inv_move_front(vinp, vaxind, vis.local_shape) | ||
winp = _inv_move_front(winp, waxind, weight.local_shape) | ||
|
||
return vinp, winp | ||
|
||
def inpaint_single(self, vobs, wobs, A): | ||
"""Inpaint a data slice.""" | ||
# Project visibilities into the dpss basis | ||
vproj = dpss.project(vobs, wobs, A) | ||
# Solve for basis coefficients | ||
b = dpss.solve(vproj, wobs, A, Si=self.snr_cov) | ||
# Get the inpainted visibilities | ||
return dpss.inpaint(A, b, vobs, wobs > 0) | ||
|
||
def _set_sel(self, data): | ||
"""Extract selection along local axis.""" | ||
self._local_sel = data.vis[:].local_bounds | ||
|
||
def _get_basis(self, samples): | ||
"""Make the DPSS basis. | ||
Returns a list of bases and a map. | ||
""" | ||
# Construct the covariance matrix and get dpss modes | ||
cov = dpss.make_covariance(samples, self.halfwidths, self.centres) | ||
modes = dpss.get_sequence(cov) | ||
# All iterations map to the same basis | ||
amap = [0] * (self._local_sel.stop - self._local_sel.start) | ||
|
||
return [modes], amap | ||
|
||
|
||
class DPSSInpaintDelay(DPSSInpaint): | ||
"""Inpaint with baseline-dependent delay cut. | ||
Attributes | ||
---------- | ||
axis : str | ||
Name of axis over which to inpaint. `freq` is the only | ||
accepted argument. | ||
za_cut : float | ||
Sine of the maximum zenith angle included in baseline-dependent delay | ||
filtering. Default is 1 which corresponds to the horizon (ie: filters out all | ||
zenith angles). Setting to zero turns off baseline dependent cut. | ||
extra_cut : float | ||
Increase the delay threshold beyond the baseline dependent term. | ||
telescope_orientation : one of ('NS', 'EW', 'none') | ||
Determines if the baseline-dependent delay cut is based on the north-south | ||
component, the east-west component or the full baseline length. For | ||
cylindrical telescopes oriented in the NS direction (like CHIME) use 'NS'. | ||
The default is 'NS'. | ||
""" | ||
|
||
axis = config.enum(["freq"], default="freq") | ||
za_cut = config.Property(proptype=float, default=1.0) | ||
extra_cut = config.Property(proptype=float, default=0.0) | ||
telescope_orientation = config.enum(["NS", "EW", "none"], default="NS") | ||
|
||
def setup(self, telescope): | ||
"""Load a telescope object. | ||
This is required to establish baseline-dependent | ||
delay cuts. | ||
Parameters | ||
---------- | ||
telescope : TransitTelescope | ||
Telescope object with baseline information. | ||
""" | ||
self.telescope = io.get_telescope(telescope) | ||
|
||
def _set_sel(self, data): | ||
"""Set the local baselines.""" | ||
prod = data.prodstack | ||
sel = self.telescope.feedmap[(prod["input_a"], prod["input_b"])] | ||
|
||
self._baselines = self.telescope.baselines[sel] | ||
|
||
def _get_basis(self, samples): | ||
"""Make the DPSS basis for each unique delay cut. | ||
Returns a list of bases and a map. | ||
""" | ||
# Calculate delay cuts based on telescope orientation | ||
if self.telescope_orientation == "NS": | ||
blen = abs(self._baselines[:, 1]) | ||
elif self.telescope_orientation == "EW": | ||
blen = abs(self._baselines[:, 0]) | ||
else: | ||
blen = np.linalg.norm(self._baselines, axis=1) | ||
|
||
# Get the delay cut for each baseline. Round delay cuts | ||
# to four decimal places to reduce repeat calculations | ||
delay_cut = self.za_cut * blen / units.c * 1.0e6 + self.extra_cut | ||
delay_cut = np.maximum(delay_cut, self.halfwidths[0]) | ||
delay_cut = np.round(delay_cut, decimals=3) | ||
|
||
# Compute covariances for each unique baseline and | ||
# map to each individual baseline. | ||
delay_cut, amap = np.unique(delay_cut, return_inverse=True) | ||
|
||
modes = [] | ||
|
||
for ii, cut in enumerate(delay_cut): | ||
self.log.debug(f"Making unique covariance {ii}/{len(delay_cut)}.") | ||
cov = dpss.make_covariance(samples, cut, 0.0) | ||
modes.append(dpss.get_sequence(cov)) | ||
|
||
return modes, amap | ||
|
||
|
||
class DPSSInpaintDelayStokesI(DPSSInpaintDelay): | ||
"""Inpaint with baseline-dependent delay cut. | ||
The input container must contain Stokes I visibilities. | ||
""" | ||
|
||
def _set_sel(self, data): | ||
"""Set the local baselines.""" | ||
# Baseline lengths extracted from the stack axis | ||
self._baselines = data.stack[data.vis[:].local_bounds] | ||
|
||
|
||
def _flatten_axes(data, axes): | ||
"""Move the specified axes to the front of a dataset. | ||
Not all the axes in `axes` need to be present, but at | ||
least one must exist | ||
""" | ||
dax = list(data.attrs["axis"]) | ||
|
||
axind = [dax.index(axis) for axis in axes if axis in dax] | ||
|
||
if not axind: | ||
raise ValueError( | ||
f"No matching axes. Dataset has axes {dax}, " | ||
f"but axes {axes} were requested." | ||
) | ||
|
||
return _move_front(data[:].local_array, axind, data.local_shape), axind | ||
|
||
|
||
def _move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray: | ||
"""Move specified axes to the front and flatten remaining axes.""" | ||
if np.isscalar(axis): | ||
axis = [axis] | ||
|
||
new_shape = [shape[i] for i in axis] | ||
# Move the N specified axes to the first N positions | ||
inds = list(range(len(axis))) | ||
# Move the specified axes to the front and flatten | ||
# the remaining axes | ||
arr = np.moveaxis(arr, axis, inds) | ||
|
||
return arr.reshape(*new_shape, -1) | ||
|
||
|
||
def _inv_move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray: | ||
"""Move axes back to their original position and expand.""" | ||
if np.isscalar(axis): | ||
axis = [axis] | ||
|
||
new_shape = [shape[i] for i in axis] | ||
new_shape += [sh for sh in shape if sh not in new_shape] | ||
inds = list(range(len(axis))) | ||
|
||
# Undo the flattening process | ||
arr = arr.reshape(new_shape) | ||
# Move axes back to their original positions | ||
arr = np.moveaxis(arr, inds, axis) | ||
|
||
return arr.reshape(shape) |