Source code for jwst.refpix.irs2_subtract_reference

import logging

import numpy as np
from astropy.stats import sigma_clipped_stats
from scipy.ndimage import convolve1d
from stdatamodels.jwst.datamodels import dqflags

log = logging.getLogger(__name__)

__all__ = [
    "correct_model",
    "float_to_complex",
    "make_irs2_mask",
    "strip_ref_pixels",
    "clobber_ref",
    "decode_mask",
    "replace_refpix",
    "flag_bad_refpix",
    "subtract_reference",
    "fft_interp_norm",
    "ols_line",
    "remove_slopes",
    "replace_bad_pixels",
    "fill_bad_regions",
]


[docs] def correct_model( output_model, irs2_model, scipix_n_default=16, refpix_r_default=4, pad=8, preserve_refpix=False ): """ Correct an input NIRSpec IRS2 datamodel using reference pixels. Parameters ---------- output_model : `~stdatamodels.jwst.datamodels.RampModel` The input science data model. irs2_model : `~stdatamodels.jwst.datamodels.IRS2Model` The reference file model for IRS2 correction. scipix_n_default : int Number of regular samples before stepping out to collect reference samples. refpix_r_default : int Number of reference samples before stepping back in to collect regular samples. pad : int The effective number of pixels sampled during the pause at the end of each row (new-row overhead). The padding is needed to preserve the phase of temporally periodic signals. preserve_refpix : bool If `True`, reference pixels will be preserved in the output. This is not used in the science pipeline, but is necessary to create new bias files for IRS2 mode. Returns ------- output_model : `~stdatamodels.jwst.datamodels.RampModel` The science data with reference output and reference pixels subtracted. """ # # Readout parameters # scipix_n 16 Number of regular samples before stepping out # to collect reference samples # refpix_r 4 Number of reference samples before stepping back # in to collect regular samples # NFOH 1 row New frame overhead (714?) # NROH 8 pixels New row overhead (`pad`) # JOH 1 pixel Jump overhead for stepping out to or in from # reference pixels # TPIX 10 microseconds Pixel dwell time # tframe 14.5889 s Frame readout time # The image and reference data will be rearranged into a 1-D array # containing the values in time order, i.e. the element number * TPIX # is the relative time at which a pixel was read out. This array will # have elements corresponding to the gaps (overheads) when no pixel was # being read. This 1-D array has length 1,458,176, which is equal to # 712 * 2048: # ((scipix_n + refpix_r + 2) * (512 // scipix_n) + NROH) * 2048 # The total frame readout time is: # (((scipix_n + refpix_r + 2) * (512 // scipix_n) + NROH) * 2048 + NFOH) # * TPIX # This agrees with the above value of tframe (14.5889 s) if NFOH = 714. # Get SCI and PIXELDQ arrays for now; that's all we need. data = output_model.data pixeldq = output_model.pixeldq # Load the reference file data. # The reference file data are complex, but they're stored as float, with # alternating real and imaginary parts. We therefore check for twice # as many rows as we actually want, and we'll divide that number by two # when allocating the arrays alpha and beta. nrows = len(irs2_model.irs2_table.field("alpha_0")) expected_nrows = 2 * 712 * 2048 if nrows != expected_nrows: log.error(f"Number of rows in reference file = {nrows}, but it should be {expected_nrows}.") output_model.meta.cal_step.refpix = "SKIPPED" return output_model alpha = np.ones((4, nrows // 2), dtype=np.complex64) beta = np.zeros((4, nrows // 2), dtype=np.complex64) alpha[0, :] = float_to_complex(irs2_model.irs2_table.field("alpha_0")) alpha[1, :] = float_to_complex(irs2_model.irs2_table.field("alpha_1")) alpha[2, :] = float_to_complex(irs2_model.irs2_table.field("alpha_2")) alpha[3, :] = float_to_complex(irs2_model.irs2_table.field("alpha_3")) beta[0, :] = float_to_complex(irs2_model.irs2_table.field("beta_0")) beta[1, :] = float_to_complex(irs2_model.irs2_table.field("beta_1")) beta[2, :] = float_to_complex(irs2_model.irs2_table.field("beta_2")) beta[3, :] = float_to_complex(irs2_model.irs2_table.field("beta_3")) scipix_n = output_model.meta.exposure.nrs_normal if scipix_n is None: log.warning(f"Keyword NRS_NORM not found; using default value {scipix_n_default}") scipix_n = scipix_n_default refpix_r = output_model.meta.exposure.nrs_reference if refpix_r is None: log.warning(f"Keyword NRS_REF not found; using default value {refpix_r_default}") refpix_r = refpix_r_default # Convert from sky (DMS) orientation to detector orientation. detector = output_model.meta.instrument.detector if detector == "NRS1": data = np.swapaxes(data, 2, 3) pixeldq = np.swapaxes(pixeldq, 0, 1) elif detector == "NRS2": data = np.swapaxes(data, 2, 3)[:, :, ::-1, ::-1] pixeldq = np.swapaxes(pixeldq, 0, 1)[::-1, ::-1] else: log.warning(f"Detector {detector}; not changing orientation (sky vs detector)") n_int = data.shape[0] # number of integrations in file ny = data.shape[-2] # 2048 nx = data.shape[-1] # 3200 # Create a mask that indicates the locations of normal vs interspersed # reference pixels. True flags normal pixels, False is reference pixels. irs2_mask = make_irs2_mask(nx, ny, scipix_n, refpix_r) # Get bad ref pixel flags from the pixeldq, collapsed along rows ref_flags = pixeldq & dqflags.pixel["BAD_REF_PIXEL"] ref_flags = np.any(ref_flags, axis=0) # If the IRS2 reference file includes data quality info, use that to # set bad reference pixel values to zero. if hasattr(irs2_model, "dq_table") and len(irs2_model.dq_table) > 0: output = irs2_model.dq_table.field("output") odd_even = irs2_model.dq_table.field("odd_even") mask = irs2_model.dq_table.field("mask") # Set interleaved reference pixel values to zero if they are flagged # as bad in the DQ extension of the CRDS reference file and not yet handled is_irs2 = ~irs2_mask.copy() # treat the refout like the other sections amplifier = nx // 5 is_irs2[:amplifier] = is_irs2[2 * amplifier : 3 * amplifier] clobber_ref( data, output, odd_even, mask, ref_flags, is_irs2, scipix_n=scipix_n, refpix_r=refpix_r ) else: log.warning("DQ extension not found in reference file") # Compute and apply the correction to one integration at a time for integ in range(n_int): log.info(f"Working on integration {integ + 1} out of {n_int}") # The input data have a length of 3200 for the last axis (X), while # the output data have an X axis with length 2048, the same as the # Y axis. This is the reason for the slice `nx-ny:` that is used # below. The last axis of output_model.data should be 2048. data0 = data[integ, :, :, :] data0 = subtract_reference( data0, alpha, beta, irs2_mask, scipix_n, refpix_r, pad, preserve_refpix=preserve_refpix ) if not preserve_refpix: data[integ, :, :, nx - ny :] = data0 else: data[integ, :, :, :] = data0 # Convert corrected data back to sky orientation if not preserve_refpix: temp_data = data[:, :, :, nx - ny :] else: temp_data = data if detector == "NRS1": output_model.data = np.swapaxes(temp_data, 2, 3) elif detector == "NRS2": output_model.data = np.swapaxes(temp_data[:, :, ::-1, ::-1], 2, 3) else: # don't change orientation output_model.data = temp_data # Strip interleaved ref pixels from the PIXELDQ and GROUPDQ extensions. if not preserve_refpix: strip_ref_pixels(output_model, irs2_mask) return output_model
[docs] def float_to_complex(data): """ Convert real and imaginary parts to complex. Parameters ---------- data : ndarray Data array with interleaved real and imaginary parts Returns ------- data : ndarray Complex array made from real and imaginary parts """ nelem = len(data) return data[0:-1:2] + 1j * data[1:nelem:2]
[docs] def make_irs2_mask(nx, ny, scipix_n, refpix_r): """ Make IRS2 mask. Parameters ---------- nx : int Number of columns in input data ny : int Number of rows in input data scipix_n : int Number of regular samples before stepping out to collect reference samples refpix_r : int Number of reference samples before stepping back in to collect regular samples Returns ------- irs2_mask : ndarray The IRS2 mask """ # Number of (scipix_n + refpix_r) per output, assuming four amplifier # outputs and one reference output. irs2_nx = max((ny, nx)) # Length of the reference output section. refout = irs2_nx // 5 part = refout - (scipix_n // 2 + refpix_r) k = part // (scipix_n + refpix_r) # `part` consists of k * (scipix_n + refpix_r) + stuff_at_end stuff_at_end = part - k * (scipix_n + refpix_r) # Create the mask that flags normal pixels as True. irs2_mask = np.ones(irs2_nx, dtype=bool) irs2_mask[0:refout] = False # Check whether the interspersed reference pixels are in the same # locations regardless of readout direction. if stuff_at_end == scipix_n // 2: # Yes, they are in the same locations. for i in range(refout + scipix_n // 2, irs2_nx + 1, scipix_n + refpix_r): irs2_mask[i : i + refpix_r] = False else: # Set the flags for each readout direction separately. nelem = refout # number of elements per output temp = np.ones(nelem, dtype=bool) for i in range(scipix_n // 2, nelem + 1, scipix_n + refpix_r): temp[i : i + refpix_r] = False j = refout irs2_mask[j : j + nelem] = temp.copy() j += nelem irs2_mask[j : j + nelem] = temp[::-1].copy() j += nelem irs2_mask[j : j + nelem] = temp.copy() j += nelem irs2_mask[j : j + nelem] = temp[::-1].copy() return irs2_mask
[docs] def strip_ref_pixels(output_model, irs2_mask): """ Copy out the normal pixels from PIXELDQ and GROUPDQ arrays. Parameters ---------- output_model : `~stdatamodels.jwst.datamodels.RampModel` The output science data model, to be modified in-place irs2_mask : ndarray of bool 1D array of length 3200. `True` means the element corresponds to a normal pixel in the raw, IRS2-format data. `False` corresponds either to a reference output pixel or to one of the interspersed reference pixel values. """ detector = output_model.meta.instrument.detector if detector == "NRS1": # Select rows. temp_array = output_model.pixeldq output_model.pixeldq = temp_array[..., irs2_mask, :] temp_array = output_model.groupdq output_model.groupdq = temp_array[..., irs2_mask, :] elif detector == "NRS2": # Reverse the direction of the mask, and select rows. temp_mask = irs2_mask[::-1] temp_array = output_model.pixeldq output_model.pixeldq = temp_array[..., temp_mask, :] temp_array = output_model.groupdq output_model.groupdq = temp_array[..., temp_mask, :] else: # Select columns. temp_array = output_model.pixeldq output_model.pixeldq = temp_array[..., irs2_mask] temp_array = output_model.groupdq output_model.groupdq = temp_array[..., irs2_mask]
[docs] def clobber_ref(data, output, odd_even, mask, ref_flags, is_irs2, scipix_n=16, refpix_r=4): """ Set some interleaved reference pixel values to zero. This is an explanation of the arithmetic for computing ``ref`` in the loop over the list of bit numbers that is returned by :func:`decode_mask`. Reads of reference pixels are interleaved with reads of science data. The pattern of science pixels (S) and reference pixels (r) looks like this:: SSSSSSSSrrrrSSSSSSSSSSSSSSSSrrrrSSSSSSSSSSSSSSSSrrrr ... rrrrSSSSSSSS Within each amplifier output, a row starts and ends with 8 (scipix_n / 2) science pixels, and the row contains 32 blocks of 4 reference pixels. There are 20 (scipix_n + refpix_r) pixels from the start of one block of reference pixels to the start of the next. ``k`` is an integer between 0 and 31, inclusive, an index to identify the block of reference pixels that we need to modify (we'll set two of the pixels to zero). ``odd_even`` is either 1 or 2, indicating that we should set either the first or the second pair of reference pixels to 0. The same set of interleaved reference pixels will be set to 0 regardless of integration number, group number, or image line number. Parameters ---------- data : ndarray The 4-D data array in detector orientation. This includes both the science and interleaved reference pixel values. ``data`` will be modified in-place to set some of the reference pixel values to zero. The science data values will not be modified. output : ndarray of int16 A 1-D array of amplifier output numbers, 1, 2, 3, or 4, read from the OUTPUT column in the DQ extension of the CRDS reference file. odd_even : ndarray of int16 A 1-D array of integer values, which may be either 1 or 2, read from the ODD_EVEN column in the DQ extension of the CRDS reference file. mask : ndarray of uint32 The MASK column (1-D) read from the CRDS reference file. ref_flags : ndarray of bool Bad reference pixel flags (1-D) , matching the data row size in detector orientation. `True` indicates a bad reference pixel. is_irs2 : ndarray of bool A 1-D array matching the data row size in detector orientation. `True` indicates an interleaved reference pixel. scipix_n : int, optional Number of regular (science) samples before stepping out to collect reference samples. refpix_r : int, optional Number of reference samples before stepping back in to collect regular samples. """ nx = data.shape[-1] # 3200 nrows = len(output) for row in range(nrows): # `offset` is the offset in pixels from the beginning of the row # to the start of the current amp output. `offset` starts with # 640 in order to skip over the reference output. offset = output[row] * (nx // 5) # nx // 5 is 640 # The readout direction alternates from one amp output to the next. if output[row] // 2 * 2 == output[row]: odd_even_row = 3 - odd_even[row] # 1 --> 2; 2 --> 1 else: odd_even_row = odd_even[row] bits = decode_mask(mask[row]) log.debug( f"output {output[row]} odd_even {odd_even[row]} mask {mask[row]} DQ bits {bits}" ) new_bad_pix = [] for k in bits: ref = offset + scipix_n // 2 + k * (scipix_n + refpix_r) + 2 * (odd_even_row - 1) log.debug(f"bad interleaved reference at pixels {ref} {ref + 1}") # track new bad pixel if not already handled for bad_pix in (ref, ref + 1): if not ref_flags[bad_pix]: new_bad_pix.append(bad_pix) ref_flags[bad_pix] = True # replace new bad pixels for bad_pix in new_bad_pix: replace_refpix( bad_pix, data, ref_flags, is_irs2, offset, offset + nx // 5, scipix_n, refpix_r, axis=-1, )
[docs] def decode_mask(mask): """ Interpret the MASK column of the DQ table. As per the ESA CDP3 document: "There is also a DQ extension that holds a binary table with three columns (OUTPUT, ODD_EVEN, and MASK) and eight rows. In the current IRS2 implementation, one jumps 32 times to odd and 32 times to even reference pixels, which are then read twice consecutively. Therefore, the masks are 32 bit unsigned integers that encode bad interleaved reference pixels/columns from left to right (increasing column index) in the native detector frame. When a bit is set, the corresponding reference data should not be used for the correction." Parameters ---------- mask : uint32 A mask value. Returns ------- bits : list A list of the indices of bits set in the ``mask`` value. """ # The bit number corresponds to a count of groups of reads of the # interleaved reference pixels. The 32-bit unsigned integer encoding # has increasing index, from left to right. flags = np.array([2**n for n in range(32)], dtype=np.uint32) temp = np.bitwise_and(flags, mask) bits = np.where(temp > 0)[0] bits = list(bits) bits.sort() return bits
[docs] def replace_refpix( bad_pix, data, bad_mask, is_irs2, low_limit, high_limit, scipix_n, refpix_r, axis=-2 ): """ Replace a bad reference pixel with its nearest neighboring value. The nearest reference group above and below the bad pixel are checked for good values in pixels with the same parity as the bad pixel. If both are good, they are averaged to determine the replacement value. If only one is good, it is directly used. If neither are good, but a neighboring readout with opposite parity is good, that value is used. If none of these options are available, the value is set to 0.0 and will be interpolated over during the IRS2 correction. The data array is modified in place. Parameters ---------- bad_pix : int The bad pixel index to replace. data : ndarray The 4-D data array containing reference and science pixels. If in science orientation, ``axis`` should be -2. If in detector orientation, ``axis`` should be set to -1. bad_mask : ndarray A 1-D boolean mask, where `True` indicates a bad reference value. Should match the shape of the data along ``axis``. is_irs2 : ndarray A 1-D boolean mask, where `True` indicates an interleaved reference pixel. Should match the shape of the data along ``axis``. low_limit : int The lower limit of the data indices along ``axis`` to check for replacement values, usually set to the bottom of the amplifier. high_limit : int The upper limit of the data indices along ``axis`` to check for replacement values, usually set to the bottom of the amplifier. scipix_n : int Number of regular (science) samples before stepping out to collect reference samples. refpix_r : int Number of reference samples before stepping back in to collect regular samples. axis : int, optional Indicates the axis containing the reference pixel values. Set to -2 for science orientation, -1 for detector orientation. """ # nearest reference pixel group, respecting parity ref_period = scipix_n + refpix_r nearest_low = bad_pix - ref_period nearest_high = bad_pix + ref_period # check for neighboring good ref pixels # to use as a fallback value neighbor = None lower_neighbor = bad_pix - 2 upper_neighbor = bad_pix + 2 if lower_neighbor >= low_limit and is_irs2[lower_neighbor] and ~bad_mask[lower_neighbor]: neighbor = lower_neighbor elif upper_neighbor < high_limit and is_irs2[upper_neighbor] and ~bad_mask[upper_neighbor]: neighbor = upper_neighbor if neighbor is not None: if axis == -1: replace_value = data[:, :, :, neighbor] else: replace_value = data[:, :, neighbor, :] else: # last resort: set to zero and allow cosine interpolation replace_value = 0.0 # try to average upper and lower v1, v2 = None, None if nearest_low >= low_limit and ~bad_mask[nearest_low]: if axis == -1: v1 = data[:, :, :, nearest_low] else: v1 = data[:, :, nearest_low, :] if nearest_high < high_limit and ~bad_mask[nearest_high]: if axis == -1: v2 = data[:, :, :, nearest_high] else: v2 = data[:, :, nearest_high, :] if v1 is not None and v2 is not None: log.debug( f" Pixel {bad_pix} replaced with value averaged from {nearest_low},{nearest_high}" ) replace_value = np.mean([v1, v2], axis=0) elif v1 is not None: log.debug(f" Pixel {bad_pix} replaced with value at {nearest_low}") replace_value = v1 elif v2 is not None: log.debug(f" Pixel {bad_pix} replaced with value at {nearest_high}") replace_value = v2 elif neighbor is not None: log.debug(f" Pixel {bad_pix} replaced with value at neighbor {neighbor}") else: log.debug(f" Pixel {bad_pix} replaced with 0.0") if axis == -1: data[:, :, :, bad_pix] = replace_value else: data[:, :, bad_pix, :] = replace_value
[docs] def flag_bad_refpix(datamodel, n_sigma=3.0, flag_only=False, replace_only=False): """ Flag bad reference pixels and replace with nearest good values. Parameters ---------- datamodel : `~stdatamodels.jwst.datamodels.JwstDataModel` The data, in science orientation. This includes both the science and interleaved reference pixel values. Data and pixeldq will be modified in-place. The science data values will not be modified. n_sigma : float, optional Flagging threshold, expressed as a factor times the standard deviation. flag_only : bool, optional If set, bad values will be flagged in the pixeldq extension but not replaced. replace_only : bool, optional If set, previously flagged bad values will be replaced, but new outliers will not be flagged. """ data = datamodel.data pixeldq = datamodel.pixeldq scipix_n = datamodel.meta.exposure.nrs_normal refpix_r = datamodel.meta.exposure.nrs_reference log.debug(f"Using flagging threshold n_sigma = {n_sigma}") # bad pixels will be replaced for all integrations and all groups nints, ngroups, ny, nx = np.shape(data) # initialize the mask with any previously marked bad pixels ref_flags = pixeldq & dqflags.pixel["BAD_REF_PIXEL"] mask_bad = np.any(ref_flags, axis=1) is_irs2 = np.full(ny, False) # calculate differences of readout pairs per amplifier amplifier = ny // 5 # 640 ref_period = scipix_n + refpix_r initial_mask = mask_bad.copy() for k in range(5): offset = int(k * amplifier) # get statistics for each integration individually, but # apply flags to all integrations for j in range(nints): ref_pix, rp_diffs, rp_means, rp_stds = [], [], [], [] int_bad = initial_mask.copy() # jump from the start of the reference pixel sequence to the next # starting pixel is from 8 to 640 by 20 for rpstart in range(scipix_n // 2, amplifier, ref_period): # amplifier offset rpstart += offset # go through the reference pixels by pairs for ri in range(0, refpix_r, 2): ri = rpstart + ri rp_d = np.mean(np.abs(data[j, :, ri + 1, :] - data[j, :, ri, :])) rp_m = np.mean(data[j, :, ri : ri + 2, :]) rp_s = np.std(data[j, :, ri : ri + 2, :]) is_irs2[ri : ri + 2] = True # exclude ref pix already flagged good = ~np.any(int_bad[ri : ri + 2]) if good and not replace_only: ref_pix.append(ri) rp_means.append(rp_m) rp_stds.append(rp_s) rp_diffs.append(rp_d) if not replace_only: ref_pix = np.array(ref_pix, dtype=int) rp_diffs = np.array(rp_diffs) rp_means = np.array(rp_means) rp_stds = np.array(rp_stds) pair_pixel = ref_pix + 1 # clipped stats for all tests mean_of_diffs, _, std_of_diffs = sigma_clipped_stats(rp_diffs, sigma=n_sigma) mean_of_means, _, std_of_means = sigma_clipped_stats(rp_means, sigma=n_sigma) mean_of_stds, _, std_of_stds = sigma_clipped_stats(rp_stds, sigma=n_sigma) # find the additional intermittent bad pixels, marking both readouts high_diffs = (rp_diffs - mean_of_diffs) > (n_sigma * std_of_diffs) high_means = (rp_means - mean_of_means) > (n_sigma * std_of_means) high_stds = (rp_stds - mean_of_stds) > (n_sigma * std_of_stds) log.debug( f"High diffs={np.sum(high_diffs)}, " f"high means={np.sum(high_means)}, " f"high stds={np.sum(high_stds)}" ) int_bad[ref_pix[high_diffs]] = True int_bad[pair_pixel[high_diffs]] = True int_bad[ref_pix[high_means]] = True int_bad[pair_pixel[high_means]] = True int_bad[ref_pix[high_stds]] = True int_bad[pair_pixel[high_stds]] = True log.debug( f"{np.sum(int_bad[offset : offset + amplifier])} " f"suspicious bad reference pixels in " f"amplifier {k}, integration {j}" ) mask_bad |= int_bad # replace any flagged pixels if desired if not flag_only: # list of all bad pixels all_bad = np.arange(offset, offset + amplifier)[mask_bad[offset : offset + amplifier]] for bad_pix in all_bad: replace_refpix( bad_pix, data, mask_bad, is_irs2, offset, offset + amplifier, scipix_n, refpix_r ) if flag_only: log.info(f"Total bad reference pixels flagged: {np.sum(mask_bad)}") else: log.info(f"Total bad reference pixels replaced: {np.sum(mask_bad)}") if pixeldq is not None: pixeldq[mask_bad] |= dqflags.pixel["BAD_REF_PIXEL"] | dqflags.pixel["DO_NOT_USE"]
[docs] def subtract_reference( data0, alpha, beta, irs2_mask, scipix_n, refpix_r, pad, preserve_refpix=False ): """ Subtract reference output and pixels for the current integration. Parameters ---------- data0 : ndarray The science data for the current integration. The shape is expected to be (ngroups, ny, 3200), where ngroups is the number of groups, and ny is the pixel height of the image. The width 3200 of the image includes the "normal" pixel data, plus the embedded reference pixels, and the reference output. alpha : ndarray This is a 2-D array of values read from the reference file. The first axis is the sector number (but only for the normal pixel data and reference pixels, not the reference output). The second axis has length 2048 * 712, corresponding to the time-ordered arrangement of the data. For each sector, the correction is applied as follows:: data * alpha[i] + reference_output * beta[i] beta : ndarray Data read from the reference file. See ``alpha`` for details. irs2_mask : ndarray A 1-D boolean array, where `True` means the element corresponds to a normal pixel in the raw, IRS2-format data; `False` corresponds either to a reference output pixel or to one of the interspersed reference pixel values. scipix_n : int Number of regular samples before stepping out to collect reference samples. refpix_r : int Number of reference samples before stepping back in to collect regular samples. pad : int The effective number of pixels sampled during the pause at the end of each row (new-row overhead). preserve_refpix : bool If `True`, reference pixels will be preserved in the output. This is not used in the science pipeline, but is necessary to create new bias files for IRS2 mode. Returns ------- data0 : ndarray The science data for the current integration, with reference output and embedded reference pixels subtracted and also removed, leaving only the normal pixel data (including the reference pixels on each edge). The shape is expected to be (ngroups, ny, nx), where ``nx = ny = 2048``. """ shape = data0.shape ngroups = shape[0] ny = shape[1] nx = shape[2] # See expression in equation 1 in IRS2_Handoff.pdf. # row = 712, if scipix_n = 16, refpix_r = 4, pad = 8. row = (scipix_n + refpix_r + 2) * 512 // scipix_n + pad # s = size(data0) # If data0 is the data for one integration, then: # s[0] would be 3 # s[1] = shape[2] = nx, the length of the X axis # s[2] = shape[1] = ny, the length of the Y axis # s[3] = shape[0] = ngroups, the number of groups (or frames) ind_n = np.arange(512, dtype=np.intp) ind_ref = np.arange(512 // scipix_n * refpix_r, dtype=np.intp) # hnorm is an array of column indices of normal pixels. # len(hnorm) = 512; len(href) = 128 # len(hnorm1) = 512; len(href1) = 128 hnorm = ind_n + refpix_r * ((ind_n + scipix_n // 2) // scipix_n) # href is an array of column indices of reference pixels. href = ind_ref + scipix_n * (ind_ref // refpix_r) + scipix_n // 2 hnorm1 = ind_n + (refpix_r + 2) * ((ind_n + scipix_n // 2) // scipix_n) href1 = ind_ref + (scipix_n + 2) * (ind_ref // refpix_r) + scipix_n // 2 + 1 unpad = np.sort(np.hstack([hnorm1, href1])) # Subtract the average over the ramp for each pixel. # b_offset is saved so that it can be added back in at the end. b_offset = data0.sum(axis=0, dtype=np.float64) / float(ngroups) data0 -= b_offset # IDL: data0 = reform(data0, s[1]/5, 5, s[2], s[3], /over) # nx/5, 5, ny, ngroups (IDL) data0 = data0.reshape((ngroups, ny, 5, nx // 5)) # current order: nx/5, 5, ny, ngroups (IDL) # current order: ngroups, ny, 5, nx/5 (numpy) # 0 1 2 3 current numpy indices # transpose to: nx/5, ny, ngroups, 5 (IDL) # transpose to: 5, ngroups, ny, nx/5 (numpy) # 2 0 1 3 transpose order for numpy # Therefore: 0 1 2 3 --> 2 0 1 3 transpose order for numpy # Here is another way to look at it: # IDL: 0 1 2 3 --> 0 2 3 1 # 3 2 1 0 1 3 2 0 (IDL indices, but reversed to numpy order) # numpy: 0 1 2 3 --> 2 0 1 3 # IDL: data0 = transpose(data0, [0,2,3,1]) data0 = np.transpose(data0, (2, 0, 1, 3)) # Flip the direction of the X axis for every other output, so the readout # direction in data0 will be the same for every output. data0[0, :, :, :] = data0[0, :, :, ::-1] data0[2, :, :, :] = data0[2, :, :, ::-1] data0[4, :, :, :] = data0[4, :, :, ::-1] # convert to time sequences of normal pixels and reference pixels. # IDL: d0 = fltarr(s[1] / 5 + pad + 2 * (512 / scipix_n), s[2], s[3], 5) # Note: nx // 5 + pad + 2 * (512 // scipix_n) = 640 + 64 + 8 = 712. # hnorm1[-1] = 703, and hnorm[-1] = 639, so 703 - 639 = 64. # 8 is the pad value. d0 = np.zeros((5, ngroups, ny, row), dtype=np.float32) # (5, ngroups, 2048, 712) # IDL: d0[hnorm1,*,*,*] = data0[hnorm,*,*,*] # IDL: d0[href1,*,*,*] = data0[href,*,*,*] # IDL: data0 = temporary(d0) d0[:, :, :, hnorm1] = data0[:, :, :, hnorm] d0[:, :, :, href1] = data0[:, :, :, href] del data0 data0 = d0.copy() del d0 # Fitting and removal of slopes per frame to remove issues at frame boundaries remove_slopes(data0, ngroups, ny, row) # Use cosine weighted interpolation to replace 0.0 values and bad # pixels and gaps. (initial guess) replace_bad_pixels(data0, ngroups, ny, row) # Fill in bad pixels, gaps, and reference data locations in the normal # data, using Fourier filtering/interpolation fill_bad_regions(data0, ngroups, ny, nx, row, scipix_n, refpix_r, pad, hnorm, hnorm1) # Setup various lists of indices that will be used in subsequent # sections for keeping/shuffling reference pixels in various arrays # # The comments are for scipix_n = 16, refpix_r = 4 n0 = 512 // scipix_n n1 = scipix_n + refpix_r + 2 ht = np.arange(n0 * n1, dtype=np.int32).reshape((n0, n1)) # (32, 22) ht[:, 0 : (scipix_n - refpix_r) // 2 + 1] = -1 ht[:, scipix_n // 2 + 1 + 3 * refpix_r // 2 :] = -1 hs = ht.copy() # ht is like href1, but extended over gaps and first and last norm pix mask = ht >= 0 ht = ht[mask] # 1-D, length = 2 * refpix_r * 512 / scipix_n # IDL: hs[scipix_n/2 + 1-refpix_r/2:scipix_n/2 + refpix_r + refpix_r/2,*] = # hs[reform([transpose(reform(indgen(refpix_r),refpix_r/2,2)), # transpose(reform(indgen(refpix_r),refpix_r/2,2))],refpix_r * 2) # + scipix_n/2 + 1,*] ; WIRED for R=2^(int) indr = np.arange(refpix_r, dtype=np.intp).reshape((2, refpix_r // 2)) # indr_t = # [[0 2] # [1 3]] indr_t = indr.transpose() # Before flattening, two_indr_t = # [[0 2 0 2] # [1 3 1 3]] # After flattening, two_indr_t = [0 2 0 2 1 3 1 3]. two_indr_t = np.concatenate((indr_t, indr_t), axis=1).flatten() two_indr_t += scipix_n // 2 + 1 # [9 11 9 11 10 12 10 12] hs[:, scipix_n // 2 + 1 - refpix_r // 2 : scipix_n // 2 + 1 + refpix_r // 2 + refpix_r] = hs[ :, two_indr_t ] mask = hs >= 0 hs = hs[mask] # hs is now 1-D if refpix_r % 4 == 2: len_hs = len(hs) temp_hs = hs.reshape(len_hs // 2, 2) temp_hs = temp_hs[:, ::-1] hs = temp_hs.flatten() # Construct the reference data: this is done in a big loop over the # four "sectors" of data in the image, corresponding to the amp regions. # Data from each sector is operated on independently and ultimately # the corrections are subtracted from each sector independently. shape_d = data0.shape for k in range(1, 5): log.debug(f"processing sector {k}") # At this point in the processing data0 has shape (5, ngroups, 2048, 712), # assuming normal IRS2 readout settings. r0k contains a subset of the # data from 1 sector of data0, with shape (ngroups, 2048, 256) r0k = np.zeros((shape_d[1], shape_d[2], shape_d[3]), dtype=np.float32) temp = data0[k, :, :, hs].copy() temp = np.transpose(temp, (1, 2, 0)) r0k[:, :, ht] = temp del temp # data0 has shape (5, ngroups, ny, row). See the section above where # d0 was created, then copied (moved) to data0. # sd[1] = shape_d[3] row (712) # sd[2] = shape_d[2] ny (2048) # sd[3] = shape_d[1] ngroups # sd[4] = shape_d[0] 5 # s is used below, so for convenience, here are the values again: # s[1] = shape[2] = nx # s[2] = shape[1] = ny # s[3] = shape[0] = ngroups # IDL and numpy differ in where they apply the normalization for the # FFT. This really shouldn't matter. normalization = float(shape_d[2] * shape_d[3]) # Set up refout if alpha was provided refout0 = None if alpha is not None: # IDL: refout0 = reform(data0[*,*,*,0], sd[1] * sd[2], sd[3]) refout0 = data0[0, :, :, :].reshape((shape_d[1], shape_d[2] * shape_d[3])) # IDL: refout0 = fft(refout0, dim=1, /over) # Divide by the length of the axis to be consistent with IDL. refout0 = np.fft.fft(refout0, axis=1) / normalization # IDL: r0 = reform(r0, sd[1] * sd[2], sd[3], 5, /over) r0k = r0k.reshape((shape_d[1], shape_d[2] * shape_d[3])) r0k = r0k.astype(np.complex64) r0k_fft = np.fft.fft(r0k, axis=1) / normalization # Note that where the IDL code uses alpha, we use beta, and vice versa. # IDL: for k=0,3 do oBridge[k]->Execute, # "for i=0, s3-1 do r0[*,i] *= alpha" r0k_fft *= beta[k - 1] # IDL: for k=0,3 do oBridge[k]->Execute, # "for i=0, s3-1 do r0[*,i] += beta * refout0[*,i]" if alpha is not None: r0k_fft += alpha[k - 1] * refout0 del refout0 # IDL: for k=0,3 do oBridge[k]->Execute, # "r0 = fft(r0, 1, dim=1, /overwrite)", /nowait r0k = np.fft.ifft(r0k_fft, axis=1) * normalization del r0k_fft # sd[1] = shape_d[3] row (712) # sd[2] = shape_d[2] ny (2048) # sd[3] = shape_d[1] ngroups # sd[4] = shape_d[0] 5 # IDL: r0 = reform(r0, sd[1], sd[2], sd[3], 5, /over) r0k = r0k.reshape(shape_d[1], shape_d[2], shape_d[3]) r0k = r0k.real if not preserve_refpix: r0k = r0k[:, :, hnorm1] else: r0k = r0k[:, :, unpad] # Subtract the correction from the data in this sector if not preserve_refpix: data0[k, :, :, hnorm1] -= np.transpose(r0k, (2, 0, 1)) else: data0[k, :, :, unpad] -= np.transpose(r0k, (2, 0, 1)) del r0k # End of loop over 4 sectors # Original data0 array has shape (5, ngroups, 2048, 712). Now that # correction has been applied, remove the interleaved reference pixels. # This leaves data0 with shape (5, ngroups, 2048, 512). if not preserve_refpix: data0 = data0[:, :, :, hnorm1] else: data0 = data0[:, :, :, unpad] # Unflip the data in the sectors that have opposite readout direction if preserve_refpix: data0[0, :, :, :] = data0[0, :, :, ::-1] data0[2, :, :, :] = data0[2, :, :, ::-1] data0[4, :, :, :] = data0[4, :, :, ::-1] # IDL: data0 = transpose(data0, [0,3,1,2]) 0, 1, 2, 3 --> 0, 3, 1, 2 # current order: 512, ny, ngroups, 5 (IDL) # current order: 5, ngroups, ny, 512 (numpy) # 0 1 2 3 current numpy indices # transpose to: 512, 5, ny, ngroups (IDL) # transpose to: ngroups, ny, 5, 512 (numpy) # 1 2 0 3 transpose order for numpy # Therefore: 0 1 2 3 --> 1 2 0 3 transpose order for numpy # After transposing, data0 will have shape (ngroups, 2048, 5, 512). data0 = np.transpose(data0, (1, 2, 0, 3)) # Reshape data0 back to its normal (ngroups, 2048, 2048), which has # the interleaved reference pixels stripped out. # IDL: data0 = reform(data0[*, 1:*, *, *], s[2], s[2], s[3], /over) # Note: ny x ny, not ny x nx. if not preserve_refpix: data0 = data0[:, :, 1:, :].reshape((ngroups, ny, ny)) else: data0 = data0.reshape((ngroups, ny, nx)) # b_offset is the average over the ramp that we subtracted near the # beginning; add it back in. # Shape of b_offset is (2048, 3200), but data0 is (ngroups, 2048, 2048), # so a mask is applied to b_offset to remove the reference pix locations. if not preserve_refpix: data0 += b_offset[..., irs2_mask] else: # add in only data value - # reference mean should be subtracted if not stripped, # except in reference sector data0[..., irs2_mask] += b_offset[..., irs2_mask] data0[..., : nx // 5] += b_offset[..., : nx // 5] return data0
[docs] def fft_interp_norm(dd0, mask0, row, hnorm, hnorm1, ny, ngroups, aa, n_iter_norm): """ Filter iteratively in FFT space of the normal pixels in each group. Parameters ---------- dd0 : ndarray Data array containing all groups, updated in place. mask0 : ndarray Mask for pixels to filter, with dimensions ny x nrow. 1 means use the pixel, 0 means do not use it. row : int Row size computed from the number of science pixels, reference pixels, and padding in an amplifier. hnorm : ndarray Array of column indices for normal pixels. hnorm1 : ndarray Shifted index values for normal pixels. ny : int Y size of data array. ngroups : int Number of groups. aa : ndarray Filter to apply. n_iter_norm : int Number of filtering iterations. """ mm = np.zeros((ny, row), dtype=np.int8) mm[:, hnorm1] = mask0[:, hnorm] hm = mm != 0 # 2-D boolean mask for j in range(ngroups): dd = dd0[j, :, :].copy() # make a copy, not a view p = dd.flatten() for _it in range(n_iter_norm): pp = np.fft.fft(p) pp *= aa p[:] = np.fft.ifft(pp).real p[hm.ravel()] = dd[hm] dd0[j, :, :] = p.reshape((ny, row))
[docs] def ols_line(x, y): """ Fit a straight line using ordinary least squares. Parameters ---------- x : ndarray Array of independent variables y : ndarray Array of dependent variables Returns ------- intercept : float Intercept of straight line fit slope : float Slope of straight line fit """ xf = x.ravel() yf = y.ravel() if len(xf) < 1 or len(yf) < 1: return 0.0, 0.0 groups = float(len(xf)) mean_x = xf.mean() mean_y = yf.mean() sum_x2 = (xf**2).sum() sum_xy = (xf * yf).sum() slope = (sum_xy - groups * mean_x * mean_y) / (sum_x2 - groups * mean_x**2) intercept = mean_y - slope * mean_x return intercept, slope
[docs] def remove_slopes(data0, ngroups, ny, row): """ Remove slopes. Fitting and removal of slopes per frame to remove issues at frame boundaries. Parameters ---------- data0 : ndarray Input data array ngroups : int Number of groups in input data ny : int Number of rows in input data row : int Row size """ time_arr = np.arange(ny * row, dtype=np.float32).reshape((ny, row)) time_arr -= time_arr.mean(dtype=np.float64) row4plus4 = np.array([0, 1, 2, 3, 2044, 2045, 2046, 2047], dtype=np.intp) # For ab_3, it should be OK to use the same index order as the IDL code. ab_3 = np.zeros((2, ngroups, 5), dtype=np.float32) for i in range(5): for k in range(ngroups): # mask is 2-D, since both row4plus4 and : have more than one element. mask = data0[i, k, row4plus4, :] != 0.0 (intercept, slope) = ols_line( time_arr[row4plus4, :][mask], data0[i, k, row4plus4, :][mask] ) ab_3[0, k, i] = intercept ab_3[1, k, i] = slope for i in range(5): for k in range(ngroups): # weight is 0 where data0 is 0, else 1. weight = (data0[i, k, :, :] != 0.0).astype(np.int8) data0[i, k, :, :] -= (ab_3[0, k, i] + time_arr * ab_3[1, k, i]) * weight
[docs] def replace_bad_pixels(data0, ngroups, ny, row): """ Replace bad pixels. Use cosine weighted interpolation to replace 0.0 values and bad pixels and gaps:: s[1] = nx s[2] = ny s[3] = ngroups Parameters ---------- data0 : ndarray Input data array ngroups : int Number of groups in input data array ny : int Number of rows in input data array row : int Row definition:: row = (scipix_n + refpix_r + 2) * 512 // scipix_n + pad # This means that if scipix_n = 16 refpix_r = 4 pad = 8 # then row = 712 """ w_ind = np.arange(1, 32, dtype=np.float32) / 32.0 w = np.sin(w_ind * np.pi) kk = 0 for jj in range(ngroups): dat = data0[kk, jj, :, :].reshape(row * ny) mask = (dat != 0.0).astype(np.float32) numerator = convolve1d(dat, w, mode="wrap") denominator = convolve1d(mask, w, mode="wrap") div_zero = denominator == 0.0 # check for divide by zero numerator = np.where(div_zero, 0.0, numerator) denominator = np.where(div_zero, 1.0, denominator) dat = numerator / denominator dat = dat.reshape(ny, row) mask = mask.reshape(ny, row) data0[kk, jj, :, :] += dat * (1.0 - mask)
[docs] def fill_bad_regions(data0, ngroups, ny, nx, row, scipix_n, refpix_r, pad, hnorm, hnorm1): """ Fill the bad regions in the data. Use Fourier filter/interpolation to replace: a. bad pixel, gaps, and reference data in the time-ordered normal data b. gaps and normal data in the time-ordered reference data This "improves" upon the cosine interpolation performed above. Parameters ---------- data0 : ndarray Input data array. Modified in place. ngroups : int Number of groups in input data. ny : int Number of rows in input data. nx : int Number of columns in input data. row : int See :func:`replace_bad_pixels`. scipix_n : int Number of regular samples before stepping out to collect reference samples. refpix_r : int Number of reference samples before stepping back in to collect regular samples. pad : int The effective number of pixels sampled during the pause at the end of each row (new-row overhead). The padding is needed to preserve the phase of temporally periodic signals. hnorm : ndarray Array of column indices for normal pixels. hnorm1 : ndarray Shifted index values for normal pixels. """ # Parameters for the filter to be used: # length of apodization cosine filter elen = 110000 // (scipix_n + refpix_r + 2) # max unfiltered frequency blen = (512 + 512 // scipix_n * (refpix_r + 2) + pad) // ( scipix_n + refpix_r + 2 ) * ny // 2 - elen // 2 # Construct the filter [1, cos, 0, cos, 1]. temp_a1 = (np.cos(np.arange(elen, dtype=np.float64) * np.pi / float(elen)) + 1.0) / 2.0 # elen = 5000 # blen = 30268 # row * ny // 2 - 2 * blen - 2 * elen = 658552 # len(temp_a2) = 729088 temp_a2 = np.concatenate( ( np.ones(blen, dtype=np.float64), temp_a1.copy(), np.zeros(row * ny // 2 - 2 * blen - 2 * elen, dtype=np.float64), temp_a1[::-1].copy(), np.ones(blen, dtype=np.float64), ) ) roll_a2 = np.roll(temp_a2, -1) aa = np.concatenate((temp_a2, roll_a2[::-1])) del temp_a1, temp_a2, roll_a2 # IDL: aa = a # replicate(1, s[3]) ; for application to the data # In IDL, aa is a 2-D array with one column of `a` for each group. In # Python, numpy broadcasting takes care of this. n_iter_norm = 3 dd0 = data0[0, :, :, :] # IDL: fft_interp_norm, dd0, 2, replicate(1, s[1] / 4, s[2], 4), # row, hnorm, hnorm1, s, aa , n_iter_norm fft_interp_norm( dd0, np.ones((ny, nx // 4), dtype=np.int64), row, hnorm, hnorm1, ny, ngroups, aa, n_iter_norm, ) data0[0, :, :, :] = dd0.copy() del aa, dd0