import logging
import warnings
from dataclasses import dataclass
import numpy as np
from scipy.optimize import minimize
from stdatamodels.jwst import datamodels
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Public names exported by a star-import of this module.
__all__ = ["PixelReplaceArrays", "PixelReplacement"]
[docs]
@dataclass
class PixelReplaceArrays:
"""
Container for data arrays and dispersion direction.
Algorithms operate on this dataclass rather than on a
`~stdatamodels.jwst.datamodels.JwstDataModel`.
This avoids the overhead of constructing intermediate DataModel objects,
which was slowing runtime for TSO data with thousands of integrations,
and provides a consistent interface for all pixel replace algorithms.
"""
data: np.ndarray
"""Science array."""
dq: np.ndarray
"""Data quality array."""
err: np.ndarray
"""Total error array."""
var_poisson: np.ndarray | None
"""Poisson variance array."""
var_rnoise: np.ndarray | None
"""Read-noise variance array."""
var_flat: np.ndarray | None
"""Flat-field variance array."""
trace_model: np.ndarray | None
"""Trace model array."""
dispersion_direction: int
"""Dispersion direction."""
class PixelReplacement:
    """
    Main class for performing pixel replacement.

    This class controls loading the input data model, selecting the
    method for pixel replacement, and executing each step. This class
    should provide modularization to allow for multiple options and possible
    future reference files.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Datamodel with bad pixels to replace. Updated in-place.
    algorithm : str, optional
        Replacement algorithm. Options are "mingrad", "fit_profile", or "trace_model".
    n_adjacent_cols : int, optional
        The number of adjacent columns to consider in building a spatial
        profile. Used only if ``algorithm`` is "fit_profile".

    Raises
    ------
    KeyError
        If ``algorithm`` does not name an implemented algorithm.
    """

    # Shortcuts for DQ Flags
    DO_NOT_USE = datamodels.dqflags.pixel["DO_NOT_USE"]
    FLUX_ESTIMATED = datamodels.dqflags.pixel["FLUX_ESTIMATED"]
    NON_SCIENCE = datamodels.dqflags.pixel["NON_SCIENCE"]

    # Shortcuts for dispersion direction for ease of reading
    HORIZONTAL = 1
    VERTICAL = 2
    LOG_SLICE = ["column", "row"]

    def __init__(self, input_model, algorithm="mingrad", n_adjacent_cols=3):
        self.input = input_model
        self.algorithm_name = algorithm
        self.n_adjacent_cols = n_adjacent_cols

        # Store algorithm options here.
        self.algorithm_dict = {
            "fit_profile": self.fit_profile,
            "mingrad": self.mingrad,
            "trace_model": self.trace_model,
        }

        # Choose algorithm from dict using input par.
        try:
            self.algorithm = self.algorithm_dict[self.algorithm_name]
        except KeyError as err:
            msg = (
                f"Algorithm name '{self.algorithm_name}' provided does "
                "not match an implemented algorithm."
            )
            log.critical(msg)
            # Carry the message on the exception too, so callers that do not
            # see the log still learn what went wrong.
            raise KeyError(msg) from err

    @staticmethod
    def _arrays_from_model(model):
        """
        Extract a `PixelReplaceArrays` from a DataModel, copying arrays.

        Parameters
        ----------
        model : `~stdatamodels.jwst.datamodels.JwstDataModel`
            Model to read the arrays and dispersion direction from.

        Returns
        -------
        `PixelReplaceArrays`
            Copies of the science, DQ, error and (where present) variance
            arrays.  The trace model, if any, is passed by reference.
        """
        # Variance arrays may be absent; carry None through in that case.
        variances = {}
        for key in ("var_poisson", "var_rnoise", "var_flat"):
            var = model[key]
            variances[key] = None if var is None else var.copy()

        return PixelReplaceArrays(
            data=model.data.copy(),
            dq=model.dq.copy(),
            err=model.err.copy(),
            var_poisson=variances["var_poisson"],
            var_rnoise=variances["var_rnoise"],
            var_flat=variances["var_flat"],
            trace_model=getattr(model, "trace_model", None),  # info-only, no need to copy
            dispersion_direction=model.meta.wcsinfo.dispersion_direction,
        )

    @staticmethod
    def _model_from_arrays(arrays, model):
        """
        Write a `PixelReplaceArrays` back into a DataModel in place.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Arrays to store.
        model : `~stdatamodels.jwst.datamodels.JwstDataModel`
            Model to update.  The trace model is not written back.
        """
        model.data = arrays.data
        model.dq = arrays.dq
        model.err = arrays.err
        model.var_poisson = arrays.var_poisson
        model.var_rnoise = arrays.var_rnoise
        model.var_flat = arrays.var_flat

    def _is_estimated(self, data, dq):
        """
        Make a mask of estimated pixels.

        Parameters
        ----------
        data : ndarray
            Data array.
        dq : ndarray
            DQ array.

        Returns
        -------
        ndarray
            Boolean array. True where pixels are flagged as FLUX_ESTIMATED and finite.
        """
        # This could be a one-liner, but the bitwise operator precedence is finicky,
        # so this helper function exists to keep the logic in one place.
        is_flagged = (dq & self.FLUX_ESTIMATED) > 0
        return np.isfinite(data) & is_flagged

    def replace(self):
        """
        Unpack model and apply pixel replacement algorithm.

        Process the input `~stdatamodels.jwst.datamodels.JwstDataModel`,
        unpack any model that holds more than one 2D spectrum, then apply
        the selected algorithm to each 2D spectrum in the input.  The input
        model is updated in place; unsupported model types are logged and
        left untouched.

        Raises
        ------
        ValueError
            If ``fit_profile`` is requested for IFU data and no region map
            can be found.
        """
        # ImageModel inputs (MIR_LRS-FIXEDSLIT)
        # or 2D SlitModel inputs (e.g. NRS_FIXEDSLIT in spec3)
        if isinstance(self.input, datamodels.ImageModel) or (
            isinstance(self.input, datamodels.SlitModel) and self.input.data.ndim == 2
        ):
            # Count pixels previously estimated first
            previous_flag = self._is_estimated(self.input.data, self.input.dq)

            arrays = self._arrays_from_model(self.input)
            arrays = self.algorithm(arrays)
            self._model_from_arrays(arrays, self.input)

            # Count newly estimated pixels
            n_replaced = np.count_nonzero(
                self._is_estimated(self.input.data, self.input.dq) & ~previous_flag
            )
            log.info(f"Input model had {n_replaced} pixels replaced.")

        elif isinstance(self.input, datamodels.IFUImageModel):
            previous_flag = self._is_estimated(self.input.data, self.input.dq)

            if self.algorithm_name in ("mingrad", "trace_model"):
                # Only copy the model here: the fit_profile path below makes
                # its own fresh copy for every slice.
                arrays = self._arrays_from_model(self.input)
                arrays = self.algorithm(arrays)
                self._model_from_arrays(arrays, self.input)

            elif self.algorithm_name == "fit_profile":
                # Attempt to run pixel replacement on each throw of the IFU slicer
                # individually.
                region_map = None
                if hasattr(self.input, "regions") and self.input.regions is not None:
                    region_map = self.input.regions
                elif self.input.meta.exposure.type == "MIR_MRS":
                    det2ab = self.input.meta.wcs.get_transform(
                        self.input.meta.wcs.available_frames[0], "alpha_beta"
                    )
                    region_map = det2ab.label_mapper.mapper.copy()
                if region_map is None:
                    raise ValueError(
                        "Cannot use algorithm='fit_profile' for IFU data with missing region map"
                    )

                region_numbers = np.unique(region_map[region_map > 0])
                for slice_num in region_numbers:
                    log.info(f"Replacing pixels for slice {slice_num}")

                    # Define a mask that is True where this trace is located
                    trace_mask = region_map == slice_num

                    arrays = self._arrays_from_model(self.input)
                    arrays.dq = np.where(
                        # When not in this trace, set NON_SCIENCE and DO_NOT_USE
                        ~trace_mask,
                        arrays.dq | self.DO_NOT_USE | self.NON_SCIENCE,
                        arrays.dq,
                    )

                    arrays = self.algorithm(arrays)

                    # Copy back only the pixels belonging to this trace
                    self.input.data = np.where(trace_mask, arrays.data, self.input.data)
                    self.input.dq = np.where(trace_mask, arrays.dq, self.input.dq)
                    self.input.err = np.where(trace_mask, arrays.err, self.input.err)
                    for var in ["var_poisson", "var_rnoise", "var_flat"]:
                        input_var = getattr(self.input, var)
                        update_var = getattr(arrays, var)
                        if input_var is not None and update_var is not None:
                            input_var = np.where(trace_mask, update_var, input_var)
                            setattr(self.input, var, input_var)

            n_replaced = np.count_nonzero(
                self._is_estimated(self.input.data, self.input.dq) & ~previous_flag
            )
            log.info(f"Input IFU frame had {n_replaced} total pixels replaced.")

        # MultiSlitModel inputs (WFSS, NRS_FIXEDSLIT, ?)
        elif isinstance(self.input, datamodels.MultiSlitModel):
            for i, _slit in enumerate(self.input.slits):
                slit_model = datamodels.SlitModel(self.input.slits[i].instance)
                arrays = self._arrays_from_model(slit_model)
                slit_model.close()

                previous_flag = self._is_estimated(arrays.data, arrays.dq)
                arrays = self.algorithm(arrays)
                n_replaced = np.count_nonzero(
                    self._is_estimated(arrays.data, arrays.dq) & ~previous_flag
                )
                log.info(f"Slit {i} had {n_replaced} pixels replaced.")

                self._model_from_arrays(arrays, self.input.slits[i])

        # CubeModel inputs are TSO (so far?); SlitModel may be NRS_BRIGHTOBJ,
        # also requiring a re-packaging of the data into 2D inputs for the algorithm
        elif isinstance(self.input, datamodels.CubeModel | datamodels.SlitModel):
            dispaxis = self.input.meta.wcsinfo.dispersion_direction
            var_names = ("var_poisson", "var_rnoise", "var_flat")

            # NOTE(review): if the model carries a trace model with the same
            # 3D shape as the data, each integration needs its own 2D plane;
            # any other shape is passed through untouched, as before.
            full_trace = getattr(self.input, "trace_model", None)
            trace_per_integration = full_trace is not None and np.shape(
                full_trace
            ) == np.shape(self.input.data)

            for i in range(len(self.input.data)):
                # Variance arrays may be absent; carry None through in that case.
                var_dict = dict.fromkeys(var_names)
                for key in var_names:
                    if self.input[key] is not None:
                        var_dict[key] = self.input[key][i].copy()

                arrays = PixelReplaceArrays(
                    data=self.input.data[i].copy(),
                    dq=self.input.dq[i].copy(),
                    err=self.input.err[i].copy(),
                    var_poisson=var_dict["var_poisson"],
                    var_rnoise=var_dict["var_rnoise"],
                    var_flat=var_dict["var_flat"],
                    trace_model=full_trace[i] if trace_per_integration else full_trace,
                    dispersion_direction=dispaxis,
                )

                previous_flag = self._is_estimated(arrays.data, arrays.dq)
                arrays = self.algorithm(arrays)
                n_replaced = np.count_nonzero(
                    self._is_estimated(arrays.data, arrays.dq) & ~previous_flag
                )
                log.info(f"Input TSO integration {i} had {n_replaced} pixels replaced.")

                self.input.data[i] = arrays.data
                self.input.dq[i] = arrays.dq
                self.input.err[i] = arrays.err
                for key in var_names:
                    if self.input[key] is not None:
                        self.input[key][i] = getattr(arrays, key)

        else:
            # This should never happen, as these should be caught in the step code.
            log.critical(f"Input model {self.input} is not supported - skipping step.")
            return

    def fit_profile(self, arrays):
        """
        Replace pixels with the profile fit method.

        Fit a profile to adjacent columns, scale profile to
        column with missing pixel(s), and find flux estimate
        from scaled profile.

        Error and variance values for the replaced pixels
        are similarly estimated, using the scales from the
        profile fit to the data.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with bad pixels now flagged with FLUX_ESTIMATED
            and holding a flux value estimated from the spatial profile.
        """
        dispaxis = arrays.dispersion_direction

        # Make a copy of the input DQ, before replacement
        input_dq = arrays.dq.copy()

        # Truncate array to region where good pixels exist
        good_pixels = np.where(~input_dq & self.DO_NOT_USE)
        if good_pixels[0].size == 0:
            log.warning(
                "No good pixels in at least one dimension of "
                "data array - skipping pixel replacement."
            )
            return arrays
        # np.where returns (row indices, column indices)
        row_range = [np.min(good_pixels[0]), np.max(good_pixels[0]) + 1]
        col_range = [np.min(good_pixels[1]), np.max(good_pixels[1]) + 1]
        valid_shape = [row_range, col_range]
        profile_cut = valid_shape[dispaxis - 1]
        # Same extent as a slice, for cutting profiles to the valid region
        cut = slice(profile_cut[0], profile_cut[1])

        # COMMENTS NOTE:
        # In comments and parameter naming, I will try to be consistent in using
        # "profile" to describe vectors in the spatial, i.e. cross-dispersion direction,
        # and "slice" to describe vectors in the spectral, i.e. dispersion direction.

        # Create set of slice indices which we can later use for profile creation
        slice_range = range(*valid_shape[2 - dispaxis])
        valid_profiles = set(slice_range)
        profiles_to_replace = []

        # Loop over axis of data array corresponding to cross-
        # dispersion direction by indexing data shape with
        # strange dispaxis argument. Keep indices in full-frame numbering scheme,
        # but only iterate through slices with valid data.
        for ind in slice_range:
            # Exclude regions with no data for dq slice.
            dq_slice = input_dq[self.custom_slice(dispaxis, ind)][cut]
            # Exclude regions with NON_SCIENCE flag
            dq_slice = np.where(dq_slice & self.NON_SCIENCE, self.NON_SCIENCE, dq_slice)
            # Find bad pixels in region containing valid data.
            n_bad = np.count_nonzero(dq_slice & self.DO_NOT_USE)
            n_nonscience = np.count_nonzero(dq_slice & self.NON_SCIENCE)
            if n_bad + n_nonscience == len(dq_slice):
                log.debug(f"Slice {ind} contains no good pixels. Skipping replacement.")
                valid_profiles.discard(ind)
            elif n_bad == 0:
                log.debug(f"Slice {ind} contains no bad pixels.")
            else:
                log.debug(f"Slice {ind} contains {n_bad} bad pixels.")
                profiles_to_replace.append(ind)

        log.debug(f"Number of profiles with at least one bad pixel: {len(profiles_to_replace)}")

        for ind in profiles_to_replace:
            # Use sets for convenient finding of neighboring slices to use in profile creation
            adjacent_inds = set(range(ind - self.n_adjacent_cols, ind + self.n_adjacent_cols + 1))
            adjacent_inds.discard(ind)
            valid_adjacent_inds = list(adjacent_inds.intersection(valid_profiles))

            # Cut out valid neighboring profiles.  Indexing with a list makes
            # copies, so masking below does not touch the input arrays.
            adjacent_condition = self.custom_slice(dispaxis, valid_adjacent_inds)
            profile_data = arrays.data[adjacent_condition]
            profile_err = arrays.err[adjacent_condition]
            if profile_data.size == 0:
                log.info(
                    f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
                    f"has no valid adjacent values - skipping."
                )
                continue

            # Mask out bad pixels
            invalid_condition = (input_dq[adjacent_condition] & self.DO_NOT_USE).astype(bool)
            profile_data[invalid_condition] = np.nan
            profile_err[invalid_condition] = np.nan

            # Add additional cut to pull only from region with valid data
            # for convenience (may not be necessary)
            region_condition = self.custom_slice(3 - dispaxis, cut)
            profile_data = profile_data[region_condition]
            profile_snr = np.abs(profile_data / profile_err[region_condition])

            # Normalize profile data
            # TODO: check on signs here - absolute max sometimes picks up
            # large negative outliers
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                profile_norm_scale = np.nanmax(
                    np.abs(profile_data), axis=(dispaxis - 1), keepdims=True
                )
                # If profile data has SNR < 5 everywhere just use unity scaling
                # (so we don't normalize to noise)
                if np.nanmax(profile_snr) < 5:
                    profile_norm_scale[:] = 1.0
            normalized = profile_data / profile_norm_scale

            # Get corresponding error and variance data and scale and mask to match
            # Handle the variance arrays as errors, so the scales match.
            norm_errors = {}
            for err_name in ("err", "var_poisson", "var_rnoise", "var_flat"):
                err_arr = getattr(arrays, err_name)
                if err_name.startswith("var"):
                    if err_arr is None:
                        continue
                    # Take the square root of the cut-out only, not the full array
                    norm_err = np.sqrt(err_arr[adjacent_condition])
                else:
                    norm_err = err_arr[adjacent_condition]
                norm_err[invalid_condition] = np.nan
                norm_errors[err_name] = norm_err[region_condition] / profile_norm_scale

            # Pull median for each pixel across profile.
            # Profile entry full of NaN values would produce a numpy
            # warning (despite well-defined behavior - return a NaN)
            # so we suppress that
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                median_profile = np.nanmedian(normalized, axis=(2 - dispaxis))

                # Do the same for the errors
                for err_name in norm_errors:
                    norm_errors[err_name] = np.nanmedian(norm_errors[err_name], axis=(2 - dispaxis))

            # Clean current profile of values flagged as bad
            current_condition = self.custom_slice(dispaxis, ind)
            current_profile = arrays.data[current_condition]
            cleaned_current = np.where(
                input_dq[current_condition] & self.DO_NOT_USE, np.nan, current_profile
            )[cut]

            replace_mask = np.where(~np.isnan(cleaned_current))[0]
            if len(replace_mask) == 0:
                log.info(
                    f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
                    f"has no valid values - skipping."
                )
                continue
            min_median = median_profile[replace_mask]
            min_current = cleaned_current[replace_mask]
            norm_current = min_current / np.max(min_current)

            # Scale median profile to current profile with bad pixel - minimize mse?
            # Only do this scaling if we didn't default to all-unity scaling above,
            # and require input values below 1e20 so that we don't overflow the
            # minimization routine with extremely bad noise.
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                if (
                    np.nanmedian(profile_norm_scale) != 1.0
                    and np.nanmax(np.abs(min_median)) < 1e20
                    and np.nanmax(np.abs(norm_current)) < 1e20
                ):
                    # TODO: check on signs here - absolute max sometimes picks up
                    # large negative outliers
                    norm_scale = minimize(
                        self.profile_mse,
                        x0=np.abs(np.nanmax(norm_current)),
                        args=(np.abs(min_median), np.abs(norm_current)),
                        method="Nelder-Mead",
                    ).x
                    scale = np.max(min_current)
                else:
                    norm_scale = 1.0
                    scale = 1.0

            # Replace pixels that are do-not-use but not non-science
            current_dq = input_dq[current_condition][cut]
            replace_condition = ((current_dq & self.DO_NOT_USE) != 0) & (
                (current_dq & self.NON_SCIENCE) == 0
            )
            replaced_current = np.where(
                replace_condition, median_profile * norm_scale * scale, cleaned_current
            )

            # Change the dq bits where old flag was DO_NOT_USE and new value is not nan.
            # DO_NOT_USE is known to be set there, so XOR clears it; FLUX_ESTIMATED
            # is OR-ed in so a pixel that already carried it keeps the flag.
            replaced_dq = np.where(
                replace_condition & ~(np.isnan(replaced_current)),
                (current_dq ^ self.DO_NOT_USE) | self.FLUX_ESTIMATED,
                current_dq,
            )

            # Update data and DQ in the output model.  ``current_condition``
            # uses an integer index, so the first indexing step is a view and
            # the assignment writes through to the underlying array.
            arrays.data[current_condition][cut] = replaced_current
            arrays.dq[current_condition][cut] = replaced_dq

            # Also update the errors and variances
            current_err = arrays.err[current_condition][cut]
            replaced_err = np.where(
                replace_condition, norm_errors["err"] * norm_scale * scale, current_err
            )
            arrays.err[current_condition][cut] = replaced_err

            # Some values in NIRSpec variances may overflow in the squares - ignore the warning.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", "overflow encountered", RuntimeWarning)
                for var in ["var_poisson", "var_rnoise", "var_flat"]:
                    if (var_arr := getattr(arrays, var)) is not None:
                        current_var = var_arr[current_condition][cut]
                        replaced_var = np.where(
                            replace_condition,
                            (norm_errors[var] * norm_scale * scale) ** 2,
                            current_var,
                        )
                        var_arr[current_condition][cut] = replaced_var
                        setattr(arrays, var, var_arr)

        return arrays

    @staticmethod
    def _interp_neighbors(arr, yindx, xindx):
        """
        Interpolate using neighboring pixels in both horizontal and vertical directions.

        Parameters
        ----------
        arr : ndarray
            2-D input array.
        yindx, xindx : ndarray
            1-D arrays, each length N, of row/column indices of the bad pixels.

        Returns
        -------
        ndarray
            Interpolations with shape of ``(2, N)`` in the horizontal (0th index)
            and vertical (1st index) directions.
        """
        horiz = (arr[yindx, xindx - 1] + arr[yindx, xindx + 1]) / 2.0
        vert = (arr[yindx - 1, xindx] + arr[yindx + 1, xindx]) / 2.0
        return np.array([horiz, vert])

    def _mingrad_interp(self, data, arrays_to_correct, xindx, yindx):
        """
        Use mingrad interpolation to interpolate arrays at selected indices.

        Parameters
        ----------
        data : ndarray
            Flux array, used to get the minimum gradient.
        arrays_to_correct : dict
            Specific arrays to correct, keyed by name. Updated in place.
            Entries that are None are skipped; the "dq" entry has its flags
            updated instead of being interpolated.
        xindx : ndarray
            X-index values for pixels to interpolate.
        yindx : ndarray
            Y-index values for pixels to interpolate.
        """
        # Absolute gradient along each axis from indata, shape (2, N), used to choose direction.
        # This is computed up front, before any of the arrays are modified.
        diffs = np.array(
            [
                np.abs(data[yindx, xindx - 1] - data[yindx, xindx + 1]),
                np.abs(data[yindx - 1, xindx] - data[yindx + 1, xindx]),
            ]
        )

        # Replace NaN diffs with inf so argmin naturally prefers the valid direction.
        # Mask is True where at least one valid direction, False elsewhere,
        # such that pixels where both diffs are inf have no usable neighbor pair and are skipped.
        diffs_with_infs = np.where(np.isnan(diffs), np.inf, diffs)  # (2, N)
        mask = ~np.all(np.isinf(diffs_with_infs), axis=0)  # (N,)

        # Per-pixel direction index: 0 = horizontal, 1 = vertical
        indmin = np.argmin(diffs_with_infs, axis=0)  # (N,)
        col_idx = np.arange(len(yindx))

        # Keep only the pixels that can actually be replaced
        indmin = indmin[mask]
        col_idx = col_idx[mask]
        y_fix = yindx[mask]
        x_fix = xindx[mask]

        # Interpolate arrays to correct in both directions then select the minimum-gradient
        # direction
        for array_name, array_data in arrays_to_correct.items():
            if array_data is None:
                continue

            if array_name == "dq":
                # Update DQ flags for pixels that were replaced.
                orig_dq = array_data[yindx, xindx]  # (N,)
                remove_dnu = (
                    mask
                    & (orig_dq & self.DO_NOT_USE).astype(bool)
                    & ~(orig_dq & self.NON_SCIENCE).astype(bool)
                )
                array_data[yindx[remove_dnu], xindx[remove_dnu]] -= self.DO_NOT_USE
                array_data[y_fix, x_fix] |= self.FLUX_ESTIMATED
            else:
                # Interpolate over data
                if array_name.startswith("var"):
                    # Take the square root first to get the scales right
                    interp_data = self._interp_neighbors(np.sqrt(array_data), yindx, xindx) ** 2
                else:
                    interp_data = self._interp_neighbors(array_data, yindx, xindx)
                array_data[y_fix, x_fix] = interp_data[indmin, col_idx]

    def mingrad(self, arrays):
        """
        Replace pixels with the minimum gradient replacement method.

        Test the gradient along the spatial and spectral axes using
        immediately adjacent pixels. Pick whichever dimension has the minimum
        absolute gradient and replace the missing pixel with the average
        of the two adjacent pixels along that dimension.

        This aims to make the process extremely local; near point sources it should do
        the replacement along the spectral axis avoiding sampling issues, while near bright
        extended emission line the replacement should be along the spatial axis. May still
        be suboptimal near bright emission lines from unresolved point sources.

        Does not attempt any replacement if a NaN value is bordered by another NaN value
        along a given axis.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in-place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with flagged bad pixels now flagged with FLUX_ESTIMATED
            and holding a flux value estimated from adjacent pixels.
        """
        log.info("Using minimum gradient method.")

        # The flux array goes first and DQ last; _mingrad_interp walks them in this order.
        names = ("data", "err", "var_poisson", "var_rnoise", "var_flat", "dq")
        arrays_to_correct = {key: getattr(arrays, key) for key in names}

        (ysize, xsize) = arrays.data.shape

        # Padding around edge of array to ensure we don't look for neighbors outside array
        pad = 1

        # Select the interior with a boolean mask rather than full-frame index grids.
        # NOTE(review): the bounds are ``pad < index < size - pad``, so row and
        # column ``pad`` are skipped as well as the array edges, although their
        # neighbors are in bounds.  Kept as-is so that results do not change.
        interior = np.zeros(arrays.data.shape, dtype=bool)
        interior[pad + 1 : ysize - pad, pad + 1 : xsize - pad] = True

        # Find non-finite pixels inside that region: Y and X indices
        yindx, xindx = np.nonzero(~np.isfinite(arrays.data) & interior)

        # Interpolate arrays in place
        self._mingrad_interp(arrays.data, arrays_to_correct, xindx, yindx)

        return arrays

    def custom_slice(self, dispaxis, index):
        """
        Construct slice for ease of use with varying dispersion axis.

        Parameters
        ----------
        dispaxis : int
            Using module-defined:

            * 1 = HORIZONTAL
            * 2 = VERTICAL
        index : int, list, range or slice
            Index or indices of cross-dispersion
            vectors to slice

        Returns
        -------
        tuple
            Slice constructed using numpy

        Raises
        ------
        IndexError
            If ``dispaxis`` is neither HORIZONTAL nor VERTICAL.
        """
        if dispaxis == self.HORIZONTAL:
            return np.s_[:, index]
        if dispaxis == self.VERTICAL:
            return np.s_[index, :]
        raise IndexError("Custom slice requires valid dispersion axis specification!")

    def profile_mse(self, scale, median, current):
        """
        Calculate mean-squared error of fitted profile.

        Parameters
        ----------
        scale : float
            Initial estimate of scale factor to bring
            normalized median profile up to current profile
        median : ndarray
            Median profile constructed from neighboring profile slices
        current : ndarray
            Current profile with bad pixels to be replaced

        Returns
        -------
        float
            Mean-squared error for minimization purposes
        """
        # NaN residuals are dropped from the sum; NaNs in ``current`` are
        # also removed from the count used as the divisor.
        n_valid = len(median) - np.count_nonzero(np.isnan(current))
        return np.nansum((current - (median * scale)) ** 2.0) / n_valid

    def _interp_along_wavelength(self, data, dispersion_direction, xindx, yindx):
        """
        Interpolate missing values along the wavelength dimension.

        Parameters
        ----------
        data : ndarray
            The array with missing values. Updated in place.
        dispersion_direction : int
            Axis containing the dispersion coordinates.
        xindx : ndarray
            X-index for missing values.
        yindx : ndarray
            Y-index for missing values.
        """
        horizontal = dispersion_direction == self.HORIZONTAL
        if horizontal:
            bad_xd, bad_wvlen = yindx, xindx
        else:
            bad_xd, bad_wvlen = xindx, yindx

        # Pixel coordinate along the dispersion axis; the same for every vector
        wl_values = np.arange(data.shape[1 if horizontal else 0])

        # make sure bad pixels are NaN to start so they are not included in interpolation
        data[yindx, xindx] = np.nan

        for xd in np.unique(bad_xd):
            wl_to_fix = bad_wvlen[bad_xd == xd]
            xdisp_values = data[xd, :] if horizontal else data[:, xd]

            valid = np.isfinite(xdisp_values)
            if not np.any(valid):
                # Nothing to interpolate from: leave this vector as NaN
                continue

            interp_wl = np.interp(wl_to_fix, wl_values[valid], xdisp_values[valid])
            if horizontal:
                data[xd, wl_to_fix] = interp_wl
            else:
                data[wl_to_fix, xd] = interp_wl

    def trace_model(self, arrays):
        """
        Replace bad pixels from the trace model if available.

        Any remaining bad pixels not available from the trace model are replaced
        with the ``mingrad`` algorithm.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in-place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with flagged bad pixels now flagged with FLUX_ESTIMATED
            and holding a flux value taken from the trace model where it is
            finite, or estimated from adjacent pixels otherwise.
        """
        trace_model = arrays.trace_model
        if trace_model is None:
            # No trace model to use: just call mingrad and return
            log.info("No trace model to use")
            return self.mingrad(arrays)

        replaceable = ~np.isfinite(arrays.data) & np.isfinite(trace_model)
        if not np.any(replaceable):
            log.info("No replaceable pixels in the trace model")
            return self.mingrad(arrays)

        log.info(f"Replacing {np.sum(replaceable)} bad pixels from trace model")
        arrays.data[replaceable] = trace_model[replaceable]

        # Update DQ flags for pixels that were replaced.
        remove_dnu = replaceable & ((arrays.dq & self.DO_NOT_USE) > 0)
        arrays.dq[remove_dnu] ^= self.DO_NOT_USE
        arrays.dq[replaceable] |= self.FLUX_ESTIMATED

        # Interpolate along wavelengths to fill in error values for missing data
        yindx, xindx = np.where(replaceable)
        for error_ext in ["err", "var_poisson", "var_rnoise", "var_flat"]:
            err = getattr(arrays, error_ext)
            if err is None:
                continue
            is_variance = error_ext.startswith("var")

            # interpolate variance as error
            if is_variance:
                err = np.sqrt(err)

            # interpolate the array in place
            self._interp_along_wavelength(err, arrays.dispersion_direction, xindx, yindx)

            # re-square the variance
            if is_variance:
                err = err**2

            # store the updated array
            setattr(arrays, error_ext, err)

        # Replace any remaining bad pixels via mingrad
        return self.mingrad(arrays)