Skip to content

Commit

Permalink
feat(interpolate): tasks to do generic and delay-specific dpss interp…
Browse files Browse the repository at this point in the history
…olation
  • Loading branch information
ljgray committed Jan 9, 2025
1 parent 8230fc6 commit 80f66ed
Showing 1 changed file with 282 additions and 0 deletions.
282 changes: 282 additions & 0 deletions draco/analysis/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
"""Tasks to do data interpolation/inpainting."""

import numpy as np
from caput import config
from cora.util import units

from ..core import io, task
from ..util import dpss


class DPSSInpaint(task.SingleTask):
    """Fill masked samples along one axis by DPSS inpainting.

    A partially-masked data series is projected onto a discrete prolate
    spheroidal sequence (DPSS, also known as Slepian sequence) basis,
    i.e. the basis which maximally concentrates spectral power inside a
    chosen window. The basis consists of the large eigenvectors of a
    covariance matrix built as a sum of `sinc` functions, each of which
    represents a top-hat in the spectral inverse of the data.

    One basis, built from `halfwidths` and `centres`, is shared by every
    slice of the dataset.

    Attributes
    ----------
    axis : str
        Name of the axis to inpaint along. Inpainting is one-dimensional
        only. Either "freq" or "ra"; default is "freq".
    iter_axes : list[str]
        Independent axes to iterate over. The container is redistributed
        over these axes before inpainting. Names that are not in the
        dataset are allowed, but at least one should be present.
        Default is ["stack", "el"].
    centres : list
        Centres of the top-hat windows. When every window is centred on
        zero the covariance matrix is real, which is significantly
        faster.
    halfwidths : list
        Half-widths of the windows. Must have the same length as
        `centres`.
    snr_cov : float
        Wiener filter inverse signal covariance, passed to the solver
        as `Si`. Default is 1.0e-3.
    copy : bool
        If true, write the result into a copy of the container;
        otherwise overwrite the input visibilities. Default is True.
    """

    axis = config.enum(["freq", "ra"], default="freq")
    iter_axes = config.Property(proptype=list, default=["stack", "el"])
    centres = config.Property(proptype=list)
    halfwidths = config.Property(proptype=list)
    snr_cov = config.Property(proptype=float, default=1.0e-3)
    copy = config.Property(proptype=bool, default=True)

    def process(self, data):
        """Inpaint the visibilities of a container.

        The input container is redistributed over `iter_axes` even when
        `copy` is true.

        Parameters
        ----------
        data : containers.VisContainer
            Container with a visibility dataset and weights.

        Returns
        -------
        out : containers.VisContainer
            Container whose masked visibilities have been filled. This is
            a copy of the input if `copy` is set, else the input itself.
            The weights are not updated.

        Raises
        ------
        ValueError
            If the container has no attribute named after `axis`.
        """
        # Look up the sample positions along the inpainting axis
        try:
            samples = getattr(data, self.axis)
        except AttributeError as exc:
            raise ValueError(f"Could not get axis `{self.axis}`.") from exc

        # Distribute over an independent axis, then record whatever
        # local selection information the basis construction needs
        data.redistribute(self.iter_axes)
        self._set_sel(data)

        # Only the visibilities are kept
        # TODO: need a weight estimate
        filled_vis = self.inpaint(data.vis, data.weight, samples)[0]

        # Pick the output container and match its distribution
        if self.copy:
            out = data.copy()
        else:
            out = data
        out.redistribute(self.iter_axes)

        out.vis[:].local_array[:] = filled_vis

        return out

    def inpaint(self, vis, weight, samples):
        """Inpaint every local slice of a visibility dataset.

        Parameters
        ----------
        vis : dataset
            Visibility dataset.
        weight : dataset
            Weight dataset matching `vis`.
        samples : array-like
            Sample positions along the inpainting axis, handed to
            `_get_basis`.

        Returns
        -------
        vinp : np.ndarray
            Inpainted visibilities with the local shape of `vis`.
        winp : np.ndarray
            Array of zeros with the local shape of `weight`; no weight
            estimate is computed yet.
        """
        # Iteration axes first, then the inpainting axis, with
        # everything else flattened into one trailing axis
        leading = (*self.iter_axes, self.axis)
        vis_flat, vis_axes = _flatten_axes(vis, leading)
        weight_flat, weight_axes = _flatten_axes(weight, leading)

        # Output buffers in the flattened layout
        vis_filled = np.zeros_like(vis_flat)
        weight_filled = np.zeros_like(weight_flat)

        # Available bases and the basis index to use for each slice
        bases, basis_index = self._get_basis(samples)

        for row in range(vis_flat.shape[0]):
            basis = bases[basis_index[row]]
            vis_filled[row] = self.inpaint_single(
                vis_flat[row], weight_flat[row], basis
            )

        # Restore the original axis order and shape
        vis_filled = _inv_move_front(vis_filled, vis_axes, vis.local_shape)
        weight_filled = _inv_move_front(
            weight_filled, weight_axes, weight.local_shape
        )

        return vis_filled, weight_filled

    def inpaint_single(self, vobs, wobs, A):
        """Inpaint one data slice with the basis `A`.

        Samples with `wobs > 0` are flagged as observed when the
        inpainted result is assembled.
        """
        # Projection of the data onto the dpss basis
        projected = dpss.project(vobs, wobs, A)
        # Basis coefficients from the regularised solve
        coeffs = dpss.solve(projected, wobs, A, Si=self.snr_cov)
        # Mask of samples carrying non-zero weight
        observed = wobs > 0

        return dpss.inpaint(A, coeffs, vobs, observed)

    def _set_sel(self, data):
        """Store the local bounds of the visibility dataset."""
        local_vis = data.vis[:]
        self._local_sel = local_vis.local_bounds

    def _get_basis(self, samples):
        """Build the single DPSS basis used for all slices.

        Returns
        -------
        modes : list
            One-element list holding the basis.
        amap : list[int]
            Basis index for each local slice; all entries are zero.
        """
        covariance = dpss.make_covariance(samples, self.halfwidths, self.centres)
        basis = dpss.get_sequence(covariance)

        # Every local slice points at the one shared basis
        bounds = self._local_sel
        n_local = bounds.stop - bounds.start
        amap = [0] * n_local

        return [basis], amap


class DPSSInpaintDelay(DPSSInpaint):
    """Inpaint with baseline-dependent delay cut.

    A separate DPSS basis is built for each unique delay cut, where the
    cut is derived from the baseline associated with each locally held
    product. The bases are always centred on zero delay; `centres` is
    not used, and `halfwidths[0]` acts as a lower bound on the cut.

    Attributes
    ----------
    axis : str
        Name of axis over which to inpaint. `freq` is the only
        accepted argument.
    za_cut : float
        Sine of the maximum zenith angle included in baseline-dependent
        delay filtering. Default is 1 which corresponds to the horizon
        (ie: filters out all zenith angles). Setting to zero turns off
        baseline dependent cut.
    extra_cut : float
        Increase the delay threshold beyond the baseline dependent term.
        Default is 0.
    telescope_orientation : one of ('NS', 'EW', 'none')
        Determines if the baseline-dependent delay cut is based on the
        north-south component, the east-west component or the full
        baseline length. For cylindrical telescopes oriented in the NS
        direction (like CHIME) use 'NS'. The default is 'NS'.
    """

    axis = config.enum(["freq"], default="freq")
    za_cut = config.Property(proptype=float, default=1.0)
    extra_cut = config.Property(proptype=float, default=0.0)
    telescope_orientation = config.enum(["NS", "EW", "none"], default="NS")

    def setup(self, telescope):
        """Load a telescope object.

        This is required to establish baseline-dependent
        delay cuts.

        Parameters
        ----------
        telescope : TransitTelescope
            Telescope object with baseline information.
        """
        self.telescope = io.get_telescope(telescope)

    def _set_sel(self, data):
        """Set the baselines of the locally held products."""
        # Only keep the products in the local section of the distributed
        # axis. `inpaint` indexes the basis map with the *local* slice
        # index, so the baselines must be in that same local ordering.
        prod = data.prodstack[data.vis[:].local_bounds]
        sel = self.telescope.feedmap[(prod["input_a"], prod["input_b"])]

        self._baselines = self.telescope.baselines[sel]

    def _get_basis(self, samples):
        """Make the DPSS basis for each unique delay cut.

        Returns
        -------
        modes : list
            One basis per unique delay cut.
        amap : np.ndarray
            For each local baseline, the index into `modes` of the basis
            to use.
        """
        # Pick the baseline component that sets the delay cut
        if self.telescope_orientation == "NS":
            blen = abs(self._baselines[:, 1])
        elif self.telescope_orientation == "EW":
            blen = abs(self._baselines[:, 0])
        else:
            blen = np.linalg.norm(self._baselines, axis=1)

        # Delay cut for each baseline.
        # NOTE(review): the 1.0e6 factor presumably converts seconds to
        # microseconds (baselines in metres) -- confirm.
        delay_cut = self.za_cut * blen / units.c * 1.0e6 + self.extra_cut
        # Never go below the first configured half-width
        delay_cut = np.maximum(delay_cut, self.halfwidths[0])
        # Round to three decimal places so that nearly-equal cuts share
        # one basis, which reduces repeat calculations
        delay_cut = np.round(delay_cut, decimals=3)

        # Compute covariances for each unique delay cut and
        # map to each individual baseline.
        delay_cut, amap = np.unique(delay_cut, return_inverse=True)

        modes = []

        # Count from one so the progress message runs 1/N ... N/N
        for ii, cut in enumerate(delay_cut, start=1):
            self.log.debug(f"Making unique covariance {ii}/{len(delay_cut)}.")
            cov = dpss.make_covariance(samples, cut, 0.0)
            modes.append(dpss.get_sequence(cov))

        return modes, amap


class DPSSInpaintDelayStokesI(DPSSInpaintDelay):
    """Baseline-dependent delay-cut inpainting for Stokes I data.

    The input container must contain Stokes I visibilities. Baselines
    are read from the container's own stack axis rather than looked up
    through the telescope feed map.
    """

    def _set_sel(self, data):
        """Store the stack-axis entries for the local section."""
        stack_axis = data.stack
        local_bounds = data.vis[:].local_bounds
        # The stack axis is taken to carry the baseline of each entry
        self._baselines = stack_axis[local_bounds]


def _flatten_axes(data, axes):
    """Bring the named axes of a dataset to the front.

    The local array of `data` is rearranged so that the requested axes
    lead, in the order given by `axes`, and all other axes are flattened
    into a single trailing axis. Names in `axes` that the dataset does
    not have are ignored, but at least one must match.

    Returns
    -------
    arr : np.ndarray
        The rearranged local array.
    axind : list[int]
        Positions in the dataset of the axes that were moved.

    Raises
    ------
    ValueError
        If none of the requested axes exists in the dataset.
    """
    names = list(data.attrs["axis"])

    # Position of every requested axis that is actually present
    positions = []
    for name in axes:
        if name in names:
            positions.append(names.index(name))

    if positions:
        return _move_front(data[:].local_array, positions, data.local_shape), positions

    raise ValueError(
        f"No matching axes. Dataset has axes {names}, "
        f"but axes {axes} were requested."
    )


def _move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray:
"""Move specified axes to the front and flatten remaining axes."""
if np.isscalar(axis):
axis = [axis]

new_shape = [shape[i] for i in axis]
# Move the N specified axes to the first N positions
inds = list(range(len(axis)))
# Move the specified axes to the front and flatten
# the remaining axes
arr = np.moveaxis(arr, axis, inds)

return arr.reshape(*new_shape, -1)


def _inv_move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray:
"""Move axes back to their original position and expand."""
if np.isscalar(axis):
axis = [axis]

new_shape = [shape[i] for i in axis]
new_shape += [sh for sh in shape if sh not in new_shape]
inds = list(range(len(axis)))

# Undo the flattening process
arr = arr.reshape(new_shape)
# Move axes back to their original positions
arr = np.moveaxis(arr, inds, axis)

return arr.reshape(shape)

0 comments on commit 80f66ed

Please sign in to comment.