# Source code for jwst.pixel_replace.pixel_replace

import logging
import warnings
from dataclasses import dataclass

import numpy as np
from scipy.optimize import minimize
from stdatamodels.jwst import datamodels

# Logger named after this module; used by all pixel replacement routines below.
log = logging.getLogger(name=__name__)

# Names exported by ``from <module> import *``.
__all__ = [
    "PixelReplaceArrays",
    "PixelReplacement",
]


[docs] @dataclass class PixelReplaceArrays: """ Container for data arrays and dispersion direction. Algorithms operate on this dataclass rather than on a `~stdatamodels.jwst.datamodels.JwstDataModel`. This avoids the overhead of constructing intermediate DataModel objects, which was slowing runtime for TSO data with thousands of integrations, and provides a consistent interface for all pixel replace algorithms. """ data: np.ndarray """Science array.""" dq: np.ndarray """Data quality array.""" err: np.ndarray """Total error array.""" var_poisson: np.ndarray | None """Poisson variance array.""" var_rnoise: np.ndarray | None """Read-noise variance array.""" var_flat: np.ndarray | None """Flat-field variance array.""" trace_model: np.ndarray | None """Trace model array.""" dispersion_direction: int """Dispersion direction."""
class PixelReplacement:
    """
    Main class for performing pixel replacement.

    This class controls loading the input data model, selecting the
    method for pixel replacement, and executing each step. This class
    should provide modularization to allow for multiple options and possible
    future reference files.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Datamodel with bad pixels to replace. Updated in-place.
    algorithm : str, optional
        Replacement algorithm. Options are "mingrad", "fit_profile", or "trace_model".
    n_adjacent_cols : int, optional
        The number of adjacent columns to consider in building a spatial profile.
        Used only if ``algorithm`` is "fit_profile".

    Raises
    ------
    KeyError
        If ``algorithm`` is not one of the implemented algorithm names.
    """

    # Shortcuts for DQ Flags
    DO_NOT_USE = datamodels.dqflags.pixel["DO_NOT_USE"]
    FLUX_ESTIMATED = datamodels.dqflags.pixel["FLUX_ESTIMATED"]
    NON_SCIENCE = datamodels.dqflags.pixel["NON_SCIENCE"]

    # Shortcuts for dispersion direction for ease of reading
    HORIZONTAL = 1
    VERTICAL = 2
    LOG_SLICE = ["column", "row"]

    # Optional variance arrays carried alongside data, err, and dq.
    _VAR_NAMES = ("var_poisson", "var_rnoise", "var_flat")

    def __init__(self, input_model, algorithm="mingrad", n_adjacent_cols=3):
        self.input = input_model
        self.algorithm_name = algorithm
        self.n_adjacent_cols = n_adjacent_cols

        # Store algorithm options here.
        self.algorithm_dict = {
            "fit_profile": self.fit_profile,
            "mingrad": self.mingrad,
            "trace_model": self.trace_model,
        }

        # Choose algorithm from dict using input par.
        try:
            self.algorithm = self.algorithm_dict[self.algorithm_name]
        except KeyError as err:
            log.critical(
                f"Algorithm name '{self.algorithm_name}' provided does "
                "not match an implemented algorithm."
            )
            raise KeyError(self.algorithm_name) from err

    @staticmethod
    def _arrays_from_model(model):
        """Extract PixelReplaceArrays from DataModel, copying arrays."""
        # numpydoc ignore: RT01
        # Copy the variance arrays that exist; missing ones stay None.
        var_dict = {}
        for key in ("var_poisson", "var_rnoise", "var_flat"):
            value = model[key]
            var_dict[key] = None if value is None else value.copy()

        return PixelReplaceArrays(
            data=model.data.copy(),
            dq=model.dq.copy(),
            err=model.err.copy(),
            var_poisson=var_dict["var_poisson"],
            var_rnoise=var_dict["var_rnoise"],
            var_flat=var_dict["var_flat"],
            trace_model=getattr(model, "trace_model", None),  # info-only, no need to copy
            dispersion_direction=model.meta.wcsinfo.dispersion_direction,
        )

    @staticmethod
    def _model_from_arrays(arrays, model):
        """Write PixelReplaceArrays back into a DataModel in place."""
        # numpydoc ignore: RT01
        model.data = arrays.data
        model.dq = arrays.dq
        model.err = arrays.err
        model.var_poisson = arrays.var_poisson
        model.var_rnoise = arrays.var_rnoise
        model.var_flat = arrays.var_flat

    def _is_estimated(self, data, dq):
        """
        Make a mask of estimated pixels.

        Parameters
        ----------
        data : ndarray
            Data array.
        dq : ndarray
            DQ array.

        Returns
        -------
        ndarray
            Boolean array. True where pixels are flagged as FLUX_ESTIMATED and finite.
        """
        # This could be a one-liner, but the bitwise operator precedence is finicky,
        # so this helper function exists to keep the logic in one place.
        is_flagged = (dq & self.FLUX_ESTIMATED) > 0
        return np.isfinite(data) & is_flagged

    def _count_newly_estimated(self, data, dq, previous_flag):
        """Count pixels estimated now that were not estimated before (``previous_flag``)."""
        # numpydoc ignore: RT01
        return np.count_nonzero(self._is_estimated(data, dq) & ~previous_flag)

    def replace(self):
        """
        Unpack model and apply pixel replacement algorithm.

        Process the input `~stdatamodels.jwst.datamodels.JwstDataModel`,
        unpack any model that holds more than one 2D spectrum, then apply
        the selected algorithm to each 2D spectrum in the input. The input
        model is updated in place.
        """
        model = self.input
        if isinstance(model, datamodels.ImageModel) or (
            isinstance(model, datamodels.SlitModel) and model.data.ndim == 2
        ):
            # ImageModel inputs (MIR_LRS-FIXEDSLIT)
            # or 2D SlitModel inputs (e.g. NRS_FIXEDSLIT in spec3)
            self._replace_single_2d()
        elif isinstance(model, datamodels.IFUImageModel):
            self._replace_ifu()
        elif isinstance(model, datamodels.MultiSlitModel):
            # MultiSlitModel inputs (WFSS, NRS_FIXEDSLIT, ?)
            self._replace_multislit()
        elif isinstance(model, datamodels.CubeModel | datamodels.SlitModel):
            # CubeModel inputs are TSO (so far?); SlitModel may be NRS_BRIGHTOBJ,
            # also requiring a re-packaging of the data into 2D inputs for the algorithm
            self._replace_integrations()
        else:
            # This should never happen, as these should be caught in the step code.
            log.critical(f"Input model {self.input} is not supported - skipping step.")
            return

    def _replace_single_2d(self):
        """Run the algorithm on a model that holds a single 2D spectrum."""
        previous_flag = self._is_estimated(self.input.data, self.input.dq)
        arrays = self.algorithm(self._arrays_from_model(self.input))
        self._model_from_arrays(arrays, self.input)

        n_replaced = self._count_newly_estimated(self.input.data, self.input.dq, previous_flag)
        log.info(f"Input model had {n_replaced} pixels replaced.")

    def _replace_ifu(self):
        """Run the algorithm on an IFU frame, slice by slice for ``fit_profile``."""
        previous_flag = self._is_estimated(self.input.data, self.input.dq)

        if self.algorithm_name == "fit_profile":
            self._replace_ifu_by_slice()
        else:
            # "mingrad" and "trace_model" work on the full frame at once.
            arrays = self.algorithm(self._arrays_from_model(self.input))
            self._model_from_arrays(arrays, self.input)

        n_replaced = self._count_newly_estimated(self.input.data, self.input.dq, previous_flag)
        log.info(f"Input IFU frame had {n_replaced} total pixels replaced.")

    def _ifu_region_map(self):
        """Return the IFU slice region map, raising ValueError if none is available."""
        # numpydoc ignore: RT01
        region_map = getattr(self.input, "regions", None)
        if region_map is None and self.input.meta.exposure.type == "MIR_MRS":
            wcs = self.input.meta.wcs
            det2ab = wcs.get_transform(wcs.available_frames[0], "alpha_beta")
            region_map = det2ab.label_mapper.mapper.copy()

        if region_map is None:
            raise ValueError(
                "Cannot use algorithm='fit_profile' for IFU data with missing region map"
            )
        return region_map

    def _replace_ifu_by_slice(self):
        """Run the algorithm on each throw of the IFU slicer individually."""
        region_map = self._ifu_region_map()

        for slice_num in np.unique(region_map[region_map > 0]):
            log.info(f"Replacing pixels for slice {slice_num}")

            # Define a mask that is True where this trace is located
            trace_mask = region_map == slice_num

            arrays = self._arrays_from_model(self.input)
            # When not in this trace, set NON_SCIENCE and DO_NOT_USE
            arrays.dq = np.where(
                ~trace_mask,
                arrays.dq | self.DO_NOT_USE | self.NON_SCIENCE,
                arrays.dq,
            )
            arrays = self.algorithm(arrays)

            # Only copy results back inside this slice.
            self.input.data = np.where(trace_mask, arrays.data, self.input.data)
            self.input.dq = np.where(trace_mask, arrays.dq, self.input.dq)
            self.input.err = np.where(trace_mask, arrays.err, self.input.err)
            for var in self._VAR_NAMES:
                input_var = getattr(self.input, var)
                update_var = getattr(arrays, var)
                if input_var is not None and update_var is not None:
                    input_var = np.where(trace_mask, update_var, input_var)
                setattr(self.input, var, input_var)

    def _replace_multislit(self):
        """Run the algorithm on each slit of a MultiSlitModel."""
        for i, slit in enumerate(self.input.slits):
            slit_model = datamodels.SlitModel(slit.instance)
            arrays = self._arrays_from_model(slit_model)
            slit_model.close()

            previous_flag = self._is_estimated(arrays.data, arrays.dq)
            arrays = self.algorithm(arrays)
            n_replaced = self._count_newly_estimated(arrays.data, arrays.dq, previous_flag)
            log.info(f"Slit {i} had {n_replaced} pixels replaced.")

            self._model_from_arrays(arrays, slit)

    def _replace_integrations(self):
        """Run the algorithm on each 2D integration of a 3D (TSO) model."""
        dispaxis = self.input.meta.wcsinfo.dispersion_direction
        full_trace_model = getattr(self.input, "trace_model", None)
        # A trace model with the same rank as the data holds one plane per
        # integration and must be sliced to match the 2D arrays below.
        per_integration_trace = full_trace_model is not None and np.ndim(
            full_trace_model
        ) == np.ndim(self.input.data)

        for i in range(len(self.input.data)):
            # Copy this integration's variance arrays, if present
            var_dict = {}
            for key in self._VAR_NAMES:
                full_var = self.input[key]
                var_dict[key] = None if full_var is None else full_var[i].copy()

            trace_model = full_trace_model[i] if per_integration_trace else full_trace_model
            arrays = PixelReplaceArrays(
                data=self.input.data[i].copy(),
                dq=self.input.dq[i].copy(),
                err=self.input.err[i].copy(),
                var_poisson=var_dict["var_poisson"],
                var_rnoise=var_dict["var_rnoise"],
                var_flat=var_dict["var_flat"],
                trace_model=trace_model,
                dispersion_direction=dispaxis,
            )

            previous_flag = self._is_estimated(arrays.data, arrays.dq)
            arrays = self.algorithm(arrays)
            n_replaced = self._count_newly_estimated(arrays.data, arrays.dq, previous_flag)
            log.info(f"Input TSO integration {i} had {n_replaced} pixels replaced.")

            self.input.data[i] = arrays.data
            self.input.dq[i] = arrays.dq
            self.input.err[i] = arrays.err
            for key in self._VAR_NAMES:
                if self.input[key] is not None:
                    self.input[key][i] = getattr(arrays, key)

    def fit_profile(self, arrays):
        """
        Replace pixels with the profile fit method.

        Fit a profile to adjacent columns, scale profile to
        column with missing pixel(s), and find flux estimate
        from scaled profile.

        Error and variance values for the replaced pixels
        are similarly estimated, using the scales from the
        profile fit to the data.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with bad pixels now flagged with FLUX_ESTIMATED and
            holding a flux value estimated from the spatial profile.
        """
        dispaxis = arrays.dispersion_direction

        # Make a copy of the input DQ, before replacement. Masks below use this
        # copy, so pixels replaced earlier in this call are still excluded when
        # building neighboring profiles.
        input_dq = arrays.dq.copy()

        # Truncate array to region where good (not DO_NOT_USE) pixels exist
        good_rows, good_cols = np.where((input_dq & self.DO_NOT_USE) == 0)
        if good_rows.size == 0:
            log.warning(
                "No good pixels in at least one dimension of "
                "data array - skipping pixel replacement."
            )
            return arrays

        # [start, stop) bounds of the good-pixel region along axis 0 and axis 1
        valid_shape = [
            [np.min(good_rows), np.max(good_rows) + 1],
            [np.min(good_cols), np.max(good_cols) + 1],
        ]
        # Bounds along the spatial (cross-dispersion) axis of each profile
        profile_cut = valid_shape[dispaxis - 1]
        profile_range = range(*profile_cut)

        # COMMENTS NOTE:
        # In comments and parameter naming, "profile" describes vectors in the
        # spatial, i.e. cross-dispersion direction, and "slice" describes
        # vectors in the spectral, i.e. dispersion direction.

        # Set of profile indices which we can later use for profile creation
        valid_profiles = set(range(*valid_shape[2 - dispaxis]))
        profiles_to_replace = set()

        # Loop over the profile indices along the dispersion axis. Keep indices in
        # full-frame numbering, but only iterate through profiles with valid data.
        for ind in range(*valid_shape[2 - dispaxis]):
            # Exclude regions with no data for dq slice.
            dq_slice = input_dq[self.custom_slice(dispaxis, ind)][profile_cut[0] : profile_cut[1]]
            # NON_SCIENCE pixels count as non-science only, not as bad pixels
            dq_slice = np.where(dq_slice & self.NON_SCIENCE, self.NON_SCIENCE, dq_slice)

            # Find bad pixels in region containing valid data.
            n_bad = np.count_nonzero(dq_slice & self.DO_NOT_USE)
            n_nonscience = np.count_nonzero(dq_slice & self.NON_SCIENCE)

            if n_bad + n_nonscience == len(dq_slice):
                log.debug(f"Slice {ind} contains no good pixels. Skipping replacement.")
                valid_profiles.discard(ind)
            elif n_bad == 0:
                log.debug(f"Slice {ind} contains no bad pixels.")
            else:
                log.debug(f"Slice {ind} contains {n_bad} bad pixels.")
                profiles_to_replace.add(ind)

        log.debug(f"Number of profiles with at least one bad pixel: {len(profiles_to_replace)}")

        for ind in profiles_to_replace:
            # Use sets for convenient finding of neighboring profiles
            adjacent_inds = set(range(ind - self.n_adjacent_cols, ind + self.n_adjacent_cols + 1))
            adjacent_inds.discard(ind)
            valid_adjacent_inds = list(adjacent_inds.intersection(valid_profiles))

            # Cut out valid neighboring profiles (list indexing returns copies)
            adjacent_condition = self.custom_slice(dispaxis, valid_adjacent_inds)
            profile_data = arrays.data[adjacent_condition]
            profile_err = arrays.err[adjacent_condition]
            if profile_data.size == 0:
                log.info(
                    f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
                    f"has no valid adjacent values - skipping."
                )
                continue

            # Mask out bad pixels
            invalid_condition = (input_dq[adjacent_condition] & self.DO_NOT_USE).astype(bool)
            profile_data[invalid_condition] = np.nan
            profile_err[invalid_condition] = np.nan

            # Restrict to the region with valid data along the spatial axis
            region_condition = self.custom_slice(3 - dispaxis, profile_range)
            profile_data = profile_data[region_condition]

            # Normalize profile data. NaN-only profiles and zero errors are
            # expected here, so silence the associated numpy warnings.
            # TODO: check on signs here - absolute max sometimes picks up
            # large negative outliers
            with warnings.catch_warnings(), np.errstate(divide="ignore", invalid="ignore"):
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                profile_snr = np.abs(profile_data / profile_err[region_condition])
                profile_norm_scale = np.nanmax(
                    np.abs(profile_data), axis=(dispaxis - 1), keepdims=True
                )
                low_snr = np.nanmax(profile_snr) < 5

            # If profile data has SNR < 5 everywhere just use unity scaling
            # (so we don't normalize to noise)
            if low_snr:
                profile_norm_scale[:] = 1.0
            normalized = profile_data / profile_norm_scale

            # Get corresponding error and variance data and scale and mask to match.
            # Handle the variance arrays as errors, so the scales match.
            norm_errors = {}
            for err_name in ("err", *self._VAR_NAMES):
                if err_name == "err":
                    err = arrays.err
                else:
                    var_arr = getattr(arrays, err_name)
                    if var_arr is None:
                        continue
                    err = np.sqrt(var_arr)
                norm_err = err[adjacent_condition]
                norm_err[invalid_condition] = np.nan
                norm_errors[err_name] = norm_err[region_condition] / profile_norm_scale

            # Pull median for each pixel across profile. A profile entry full of
            # NaN values returns NaN with a numpy warning, which is suppressed.
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                median_profile = np.nanmedian(normalized, axis=(2 - dispaxis))
                # Do the same for the errors
                for err_name in norm_errors:
                    norm_errors[err_name] = np.nanmedian(
                        norm_errors[err_name], axis=(2 - dispaxis)
                    )

            # Clean current profile of values flagged as bad
            current_condition = self.custom_slice(dispaxis, ind)
            current_dq = input_dq[current_condition][profile_range]
            cleaned_current = np.where(
                current_dq & self.DO_NOT_USE,
                np.nan,
                arrays.data[current_condition][profile_range],
            )
            valid_idx = np.where(~np.isnan(cleaned_current))[0]
            if len(valid_idx) == 0:
                log.info(
                    f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
                    f"has no valid values - skipping."
                )
                continue

            min_median = median_profile[valid_idx]
            min_current = cleaned_current[valid_idx]
            current_max = np.max(min_current)
            with np.errstate(divide="ignore", invalid="ignore"):
                norm_current = min_current / current_max

            # Scale median profile to current profile with bad pixel by minimizing
            # the MSE. Only do this scaling if we didn't default to all-unity
            # scaling above, and require input values below 1e20 so that we don't
            # overflow the minimization routine with extremely bad noise.
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
                if (
                    (np.nanmedian(profile_norm_scale) != 1.0)
                    & (np.nanmax(np.abs(min_median)) < 1e20)
                    & (np.nanmax(np.abs(norm_current)) < 1e20)
                ):
                    # TODO: check on signs here - absolute max sometimes picks up
                    # large negative outliers
                    norm_scale = minimize(
                        self.profile_mse,
                        x0=np.abs(np.nanmax(norm_current)),
                        args=(np.abs(min_median), np.abs(norm_current)),
                        method="Nelder-Mead",
                    ).x
                    scale = current_max
                else:
                    norm_scale = 1.0
                    scale = 1.0

            # Replace pixels that are DO_NOT_USE but not NON_SCIENCE
            replace_condition = ((current_dq & self.DO_NOT_USE) != 0) & (
                (current_dq & self.NON_SCIENCE) == 0
            )
            replaced_current = np.where(
                replace_condition, median_profile * norm_scale * scale, cleaned_current
            )

            # Clear DO_NOT_USE and set FLUX_ESTIMATED where a value was estimated.
            # XOR is safe for DO_NOT_USE (replace_condition guarantees it is set);
            # FLUX_ESTIMATED is OR'd so an already-set flag is never toggled off.
            replaced_dq = np.where(
                replace_condition & ~np.isnan(replaced_current),
                (current_dq ^ self.DO_NOT_USE) | self.FLUX_ESTIMATED,
                current_dq,
            )

            # Update data and DQ in the output arrays (the basic-indexed column or
            # row is a view, so the fancy-index assignment writes through)
            arrays.data[current_condition][profile_range] = replaced_current
            arrays.dq[current_condition][profile_range] = replaced_dq

            # Also update the errors; `scale` may be negative for negative-flux
            # profiles, but an uncertainty must be non-negative.
            current_err = arrays.err[current_condition][profile_range]
            replaced_err = np.where(
                replace_condition, np.abs(norm_errors["err"] * norm_scale * scale), current_err
            )
            arrays.err[current_condition][profile_range] = replaced_err

            # Some values in NIRSpec variances may overflow in the squares - ignore the warning.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", "overflow encountered", RuntimeWarning)
                for var in self._VAR_NAMES:
                    var_arr = getattr(arrays, var)
                    if var_arr is None:
                        continue
                    current_var = var_arr[current_condition][profile_range]
                    replaced_var = np.where(
                        replace_condition,
                        (norm_errors[var] * norm_scale * scale) ** 2,
                        current_var,
                    )
                    var_arr[current_condition][profile_range] = replaced_var

        return arrays

    @staticmethod
    def _interp_neighbors(arr, yindx, xindx):
        """
        Interpolate using neighboring pixels in both horizontal and vertical directions.

        Parameters
        ----------
        arr : ndarray
            2-D input array.
        yindx, xindx : ndarray
            1-D arrays, each length N, of row/column indices of the bad pixels.

        Returns
        -------
        ndarray
            Interpolations with shape of ``(2, N)`` in the horizontal (0th index)
            and vertical (1st index) directions.
        """
        horiz = (arr[yindx, xindx - 1] + arr[yindx, xindx + 1]) / 2.0
        vert = (arr[yindx - 1, xindx] + arr[yindx + 1, xindx]) / 2.0
        return np.array([horiz, vert])

    def _mingrad_interp(self, data, arrays_to_correct, xindx, yindx):
        """
        Use mingrad interpolation to interpolate arrays at selected indices.

        Parameters
        ----------
        data : ndarray
            Flux array, used to get the minimum gradient.
        arrays_to_correct : dict
            Specific arrays to correct. Updated in place.
        xindx : ndarray
            X-index values for pixels to interpolate.
        yindx : ndarray
            Y-index values for pixels to interpolate.
        """
        # Absolute gradient along each axis from data, shape (2, N), used to choose direction
        diffs = np.array(
            [
                np.abs(data[yindx, xindx - 1] - data[yindx, xindx + 1]),
                np.abs(data[yindx - 1, xindx] - data[yindx + 1, xindx]),
            ]
        )

        # Replace NaN diffs with inf so argmin naturally prefers the valid direction.
        # Mask is True where at least one direction is usable; pixels where both
        # diffs are inf have no usable neighbor pair and are skipped.
        diffs_with_infs = np.where(np.isnan(diffs), np.inf, diffs)  # (2, N)
        mask = ~np.all(np.isinf(diffs_with_infs), axis=0)  # (N,)

        # Per-pixel direction index (0 = horizontal, 1 = vertical), restricted to usable pixels
        indmin = np.argmin(diffs_with_infs, axis=0)[mask]
        col_idx = np.arange(len(yindx))[mask]

        # Interpolate arrays to correct in both directions then select the
        # minimum-gradient direction
        for array_name, array_data in arrays_to_correct.items():
            if array_data is None:
                continue

            if array_name == "dq":
                # Update DQ flags for pixels that were replaced.
                orig_dq = array_data[yindx, xindx]  # (N,)
                remove_dnu = (
                    mask
                    & (orig_dq & self.DO_NOT_USE).astype(bool)
                    & ~(orig_dq & self.NON_SCIENCE).astype(bool)
                )
                array_data[yindx[remove_dnu], xindx[remove_dnu]] -= self.DO_NOT_USE
                array_data[yindx[mask], xindx[mask]] |= self.FLUX_ESTIMATED
            else:
                if array_name.startswith("var"):
                    # Take the square root first to get the scales right
                    interp_data = self._interp_neighbors(np.sqrt(array_data), yindx, xindx) ** 2
                else:
                    interp_data = self._interp_neighbors(array_data, yindx, xindx)
                array_data[yindx[mask], xindx[mask]] = interp_data[indmin, col_idx]

    def mingrad(self, arrays):
        """
        Replace pixels with the minimum gradient replacement method.

        Test the gradient along the spatial and spectral axes using
        immediately adjacent pixels. Pick whichever dimension has the minimum
        absolute gradient and replace the missing pixel with the average
        of the two adjacent pixels along that dimension.

        This aims to make the process extremely local; near point sources it should
        do the replacement along the spectral axis avoiding sampling issues, while near
        bright extended emission line the replacement should be along the spatial axis.
        May still be suboptimal near bright emission lines from unresolved point sources.

        Does not attempt any replacement if a NaN value is bordered by another NaN value
        along a given axis. Pixels on the outermost row/column are never replaced
        because they lack a neighbor on one side.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in-place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with flagged bad pixels now flagged with FLUX_ESTIMATED and
            holding a flux value estimated from adjacent pixels.
        """
        log.info("Using minimum gradient method.")

        # DQ last; order otherwise matches the original processing order.
        arrays_to_correct = {
            key: getattr(arrays, key)
            for key in ("data", "err", "var_poisson", "var_rnoise", "var_flat", "dq")
        }

        # Padding around edge of array to ensure we don't look for neighbors
        # outside the array. Index ``pad`` itself is a valid candidate (its
        # neighbor at index 0 exists), so only the first/last ``pad`` rows and
        # columns are excluded, symmetrically on both sides.
        pad = 1
        candidates = ~np.isfinite(arrays.data)
        candidates[:pad, :] = False
        candidates[-pad:, :] = False
        candidates[:, :pad] = False
        candidates[:, -pad:] = False

        # Y (row) and X (column) indices of the non-finite pixels to replace
        yindx, xindx = np.where(candidates)

        # Interpolate arrays in place
        self._mingrad_interp(arrays.data, arrays_to_correct, xindx, yindx)

        return arrays

    def custom_slice(self, dispaxis, index):
        """
        Construct slice for ease of use with varying dispersion axis.

        Parameters
        ----------
        dispaxis : int
            Using module-defined:

            * 1 = HORIZONTAL
            * 2 = VERTICAL

        index : int or list
            Index or indices of cross-dispersion vectors to slice

        Returns
        -------
        tuple
            Slice constructed using numpy

        Raises
        ------
        IndexError
            If ``dispaxis`` is neither HORIZONTAL nor VERTICAL.
        """
        if dispaxis == self.HORIZONTAL:
            return np.s_[:, index]
        elif dispaxis == self.VERTICAL:
            return np.s_[index, :]
        else:
            raise IndexError("Custom slice requires valid dispersion axis specification!")

    def profile_mse(self, scale, median, current):
        """
        Calculate mean-squared error of fitted profile.

        Parameters
        ----------
        scale : float
            Initial estimate of scale factor to bring normalized median profile up to current profile
        median : ndarray
            Median profile constructed from neighboring profile slices
        current : ndarray
            Current profile with bad pixels to be replaced

        Returns
        -------
        float
            Mean-squared error for minimization purposes
        """
        return np.nansum((current - (median * scale)) ** 2.0) / (
            len(median) - np.count_nonzero(np.isnan(current))
        )

    def _interp_along_wavelength(self, data, dispersion_direction, xindx, yindx):
        """
        Interpolate missing values along the wavelength dimension.

        Parameters
        ----------
        data : ndarray
            The array with missing values. Updated in place.
        dispersion_direction : int
            Axis containing the dispersion coordinates.
        xindx : ndarray
            X-index for missing values.
        yindx : ndarray
            Y-index for missing values.
        """
        if dispersion_direction == self.HORIZONTAL:
            bad_xd = yindx
            bad_wvlen = xindx
        else:
            bad_xd = xindx
            bad_wvlen = yindx

        shape = data.shape
        yy, xx = np.mgrid[: shape[0], : shape[1]]

        # make sure bad pixels are NaN to start so they are not included in interpolation
        data[yindx, xindx] = np.nan

        for xd in np.unique(bad_xd):
            wl_to_fix = bad_wvlen[bad_xd == xd]
            if dispersion_direction == self.HORIZONTAL:
                xdisp_values = data[xd, :]
                wl_values = xx[xd, :]
            else:
                xdisp_values = data[:, xd]
                wl_values = yy[:, xd]

            valid = np.isfinite(xdisp_values)
            if not np.any(valid):
                continue

            interp_wl = np.interp(wl_to_fix, wl_values[valid], xdisp_values[valid])
            if dispersion_direction == self.HORIZONTAL:
                data[xd, wl_to_fix] = interp_wl
            else:
                data[wl_to_fix, xd] = interp_wl

    def trace_model(self, arrays):
        """
        Replace bad pixels from the trace model if available.

        Any remaining bad pixels not available from the trace model
        are replaced with the ``mingrad`` algorithm.

        Parameters
        ----------
        arrays : `PixelReplaceArrays`
            Pixel arrays and dispersion direction for the 2D spectrum to process.
            Arrays are modified in-place.

        Returns
        -------
        arrays : `PixelReplaceArrays`
            The input with flagged bad pixels now flagged with FLUX_ESTIMATED and
            holding a flux value estimated from adjacent pixels.
        """
        trace_model = arrays.trace_model
        if trace_model is None:
            # No trace model to use: just call mingrad and return
            log.info("No trace model to use")
            return self.mingrad(arrays)

        if np.shape(trace_model) != arrays.data.shape:
            # A mismatched trace model cannot be used pixel-for-pixel.
            log.warning(
                f"Trace model shape {np.shape(trace_model)} does not match "
                f"data shape {arrays.data.shape}; using mingrad only"
            )
            return self.mingrad(arrays)

        replaceable = ~np.isfinite(arrays.data) & np.isfinite(trace_model)
        if not np.any(replaceable):
            log.info("No replaceable pixels in the trace model")
            return self.mingrad(arrays)

        log.info(f"Replacing {np.sum(replaceable)} bad pixels from trace model")
        arrays.data[replaceable] = trace_model[replaceable]

        # Update DQ flags for pixels that were replaced.
        remove_dnu = replaceable & ((arrays.dq & self.DO_NOT_USE) > 0)
        arrays.dq[remove_dnu] ^= self.DO_NOT_USE
        arrays.dq[replaceable] |= self.FLUX_ESTIMATED

        # Interpolate along wavelengths to fill in error values for missing data
        yindx, xindx = np.where(replaceable)
        for error_ext in ["err", "var_poisson", "var_rnoise", "var_flat"]:
            err = getattr(arrays, error_ext)
            if err is None:
                continue

            # interpolate variance as error
            is_var = error_ext.startswith("var")
            if is_var:
                err = np.sqrt(err)

            # interpolate the array in place
            self._interp_along_wavelength(err, arrays.dispersion_direction, xindx, yindx)

            # re-square the variance
            if is_var:
                err = err**2

            # store the updated array
            setattr(arrays, error_ext, err)

        # Replace any remaining bad pixels via mingrad
        return self.mingrad(arrays)