"""
Utility function for assign_wcs.
"""
import logging
import functools
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.utils.misc import isiterable
from astropy.io import fits
from astropy.modeling import models as astmodels
from astropy.table import QTable
from astropy.constants import c
from typing import Union, List
from gwcs import WCS
from gwcs.wcstools import wcs_from_fiducial, grid_from_bounding_box
from gwcs import utils as gwutils
from stpipe.exceptions import StpipeExitException
from stdatamodels.jwst.datamodels import JwstDataModel
from stdatamodels.jwst.datamodels import WavelengthrangeModel
from stdatamodels.jwst.transforms.models import GrismObject
from . import pointing
from ..lib.catalog_utils import SkyObject
# Module-level logger. Its level is explicitly set to DEBUG, so records of
# every level are passed on to handlers; any further filtering is up to the
# handlers attached to it or its ancestors.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# NOTE(review): not referenced in the visible part of this module; presumably
# the maximum SIP polynomial degree used by code further down the file -
# confirm.
_MAX_SIP_DEGREE = 6

# Public API of this module. NOTE(review): ``velocity_correction``,
# ``wrap_ra`` and ``update_fits_wcsinfo`` are not defined in the visible
# portion of the file; presumably they are defined further down.
__all__ = ["reproject", "wcs_from_footprints", "velocity_correction",
           "MSAFileError", "NoDataOnDetectorError", "compute_scale",
           "calc_rotation_matrix", "wrap_ra", "update_fits_wcsinfo"]
class MSAFileError(Exception):
    """Exception for errors involving an MSA file.

    Parameters
    ----------
    message : str
        Description of the problem; stored as the exception's only argument.
    """

    def __init__(self, message):
        super().__init__(message)
class NoDataOnDetectorError(StpipeExitException):
    """Raised when the WCS solution shows that no data fall on the detector.

    When WCS solutions are available and they indicate that no data will be
    present, this exception is raised.

    A specific example is NIRSpec and the NRS2 detector: for various MSA
    configurations it is possible that no dispersed spectra appear on NRS2.
    That is not a calibration failure, but it has to be signalled so the
    calling architecture is aware of it.

    Parameters
    ----------
    message : str, optional
        Custom message; a default text is used when omitted.
    """

    def __init__(self, message=None):
        default_message = 'WCS solution indicate that no science is in the data.'
        # The leading 64 is the exit status stpipe CLI tools use when this
        # exception is raised.
        super().__init__(64, default_message if message is None else message)
def _domain_to_bounding_box(domain):
# TODO: remove this when domain is completely removed
bb = tuple([(item['lower'], item['upper']) for item in domain])
if len(bb) == 1:
bb = bb[0]
return bb
def reproject(wcs1, wcs2):
    """
    Build a pixel-to-pixel mapping from one WCS to another.

    The returned callable applies ``wcs1.forward_transform`` and feeds the
    resulting world coordinates into ``wcs2.backward_transform``.

    Parameters
    ----------
    wcs1, wcs2 : `~gwcs.wcs.WCS`
        Source and destination WCS objects.

    Returns
    -------
    _reproject : func
        Callable taking x, y positions in ``wcs1`` and returning the
        corresponding positions in ``wcs2``.
    """
    def _reproject(x, y):
        return wcs2.backward_transform(*wcs1.forward_transform(x, y))

    return _reproject
def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
                  disp_axis: Union[int, None] = None,
                  pscale_ratio: Union[float, None] = None) -> float:
    """Compute the pixel scale of a WCS at a fiducial point.

    The fiducial is converted to pixel coordinates, offset by one pixel
    along two pixel axes, and the angular separations between the resulting
    sky positions give the x and y scales.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.
    fiducial : tuple or `~numpy.ndarray`
        Input fiducial of (RA, DEC) or (RA, DEC, Wavelength), in the axis
        order of ``wcs.output_frame``, used in calculating reference points.
    disp_axis : int, optional
        Dispersion axis, using the same convention as
        ``wcsinfo.dispersion_direction``. Required when the output frame of
        ``wcs`` contains a spectral axis.
    pscale_ratio : float, optional
        Ratio of input to output pixel scale. When given, both directional
        scales are multiplied by it.

    Returns
    -------
    scale : float
        Angular scale in degrees (separations of ``SkyCoord`` objects
        built with ``unit="deg"``). For a spectral WCS this is the
        cross-dispersion scale: the y-direction scale when
        ``disp_axis == 1``, otherwise the x-direction scale. For a purely
        spatial WCS it is the geometric mean of the x and y scales.

    Raises
    ------
    ValueError
        If the WCS is spectral and ``disp_axis`` is not given.
    """
    spectral = 'SPECTRAL' in wcs.output_frame.axes_type

    if spectral and disp_axis is None:
        raise ValueError('If input WCS is spectral, a disp_axis must be given')

    crpix = np.array(wcs.invert(*fiducial))

    # Unit step on the first spatial axis; np.roll moves it to the next axis
    # for the second offset point.
    delta = np.zeros_like(crpix)
    spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0]
    delta[spatial_idx[0]] = 1

    # Columns: fiducial pixel, +1 along one axis, +1 along the other axis.
    crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T
    crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

    coords = SkyCoord(ra=crval_with_offsets[spatial_idx[0]],
                      dec=crval_with_offsets[spatial_idx[1]], unit="deg")
    xscale = np.abs(coords[0].separation(coords[1]).value)
    yscale = np.abs(coords[0].separation(coords[2]).value)

    if pscale_ratio is not None:
        xscale *= pscale_ratio
        yscale *= pscale_ratio

    if spectral:
        # Assuming scale doesn't change with wavelength
        # Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction
        return yscale if disp_axis == 1 else xscale

    return np.sqrt(xscale * yscale)
def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> List[float]:
    """Calculate the PC rotation matrix elements.

    Parameters
    ----------
    roll_ref : float
        Telescope roll angle of V3 North over East at the ref. point in radians
    v3i_yang : float
        The angle between ideal Y-axis and V3 in radians.
    vparity : int
        The x-axis parity, usually taken from the JWST SIAF parameter
        VIdlParity. Must be 1 or -1.

    Returns
    -------
    matrix : list of float
        ``[pc1_1, pc1_2, pc2_1, pc2_2]`` where, with
        ``theta = roll_ref - vparity * v3i_yang``::

            pc1_1 =  vparity * cos(theta)
            pc1_2 =  sin(theta)
            pc2_1 = -vparity * sin(theta)
            pc2_2 =  cos(theta)

        Reshaping this list row-major to (2, 2) gives
        ``[[pc1_1, pc1_2], [pc2_1, pc2_2]]``.

    Raises
    ------
    ValueError
        If ``vparity`` is not 1 or -1.
    """
    if vparity not in (1, -1):
        raise ValueError(f'vparity should be 1 or -1. Input was: {vparity}')

    theta = roll_ref - vparity * v3i_yang
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return [vparity * cos_t, sin_t, -vparity * sin_t, cos_t]
def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None,
                        pscale_ratio=None, pscale=None, rotation=None,
                        shape=None, crpix=None, crval=None):
    """
    Create a WCS from a list of input data models.

    A fiducial point in the output coordinate frame is created from the
    footprints of all WCS objects. For a spatial frame this is the center
    of the union of the footprints. For a spectral frame the fiducial is in
    the beginning of the footprint range.
    If ``refmodel`` is None, the first model in the list is used as the
    reference. The output and input coordinate frames are taken from
    ``refmodel``; the projection is always gnomonic (``Pix2Sky_TAN``).
    If ``transform`` is not supplied, a compound transform is created using
    CDELTs and PC.
    The bounding box of the returned WCS is always computed from the
    footprints of all input WCSs.

    Parameters
    ----------
    dmodels : list of `~jwst.datamodels.JwstDataModel`
        A non-empty iterable of data models whose ``meta.wcs`` are
        `~gwcs.wcs.WCS` objects.
    refmodel : `~jwst.datamodels.JwstDataModel`, optional
        This model's WCS is used as a reference: the output coordinate
        frame, the input frame and (when ``transform`` is not given) the
        scaling and rotation transform are created from it. If not supplied
        the first model in the list is used as ``refmodel``.
    transform : `~astropy.modeling.core.Model`, optional
        A transform, passed to :meth:`~gwcs.wcstools.wcs_from_fiducial`
        If not supplied Scaling | Rotation is computed from ``refmodel``.
    bounding_box : tuple, optional
        Bounding box passed to the ``footprint`` of every input WCS when
        computing the fiducial point. It does NOT set the bounding box of
        the returned WCS.
    pscale_ratio : float, None, optional
        Ratio of input to output pixel scale. Ignored when either
        ``transform`` or ``pscale`` are provided.
    pscale : float, None, optional
        Absolute pixel scale in degrees. When provided (and non-zero),
        overrides ``pscale_ratio``. Ignored when ``transform`` is provided.
    rotation : float, None, optional
        Position angle of output image's Y-axis relative to North.
        A value of 0.0 would orient the final output image to be North up.
        The default of `None` specifies that the images will not be rotated,
        but will instead be resampled in the default orientation for the camera
        with the x and y axes of the resampled image corresponding
        approximately to the detector axes. Ignored when ``transform`` is
        provided.
    shape : tuple of int, None, optional
        Shape of the image (data array) using ``numpy.ndarray`` convention
        (``ny`` first and ``nx`` second). This value will be assigned to
        ``pixel_shape`` and ``array_shape`` properties of the returned
        WCS object.
    crpix : tuple of float, None, optional
        Position of the reference pixel in the image array. If ``crpix`` is
        not specified, it is set to the pixel position of the fiducial point
        relative to the minimum corner of the output bounding box.
    crval : tuple of float, None, optional
        Right ascension and declination of the reference pixel. Automatically
        computed if not provided.

    Returns
    -------
    wnew : `~gwcs.wcs.WCS`
        The new WCS, with ``bounding_box``, ``pixel_shape`` and
        ``array_shape`` set.

    Raises
    ------
    ValueError
        If ``dmodels`` is not iterable or is empty.
    TypeError
        If an input WCS is not a `gwcs.WCS`, or ``refmodel`` is not a
        `JwstDataModel`.
    """
    bb = bounding_box

    # Validate the caller's input (checking the list built below would
    # always succeed).
    if not isiterable(dmodels):
        raise ValueError("Expected 'dmodels' to be an iterable of data models.")
    wcslist = [im.meta.wcs for im in dmodels]
    if not wcslist:
        raise ValueError("Expected at least one input model in 'dmodels'.")

    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in wcslist are to be instances of gwcs.WCS.")

    if refmodel is None:
        refmodel = dmodels[0]
    elif not isinstance(refmodel, JwstDataModel):
        raise TypeError("Expected refmodel to be an instance of DataModel.")

    fiducial = compute_fiducial(wcslist, bb)
    if crval is not None:
        # overwrite spatial axes with user-provided CRVAL:
        i = 0
        for k, axt in enumerate(wcslist[0].output_frame.axes_type):
            if axt == 'SPATIAL':
                fiducial[k] = crval[i]
                i += 1

    ref_fiducial = np.array([refmodel.meta.wcsinfo.ra_ref, refmodel.meta.wcsinfo.dec_ref])

    prj = astmodels.Pix2Sky_TAN()

    if transform is None:
        transform = []
        wcsinfo = pointing.wcsinfo_from_model(refmodel)
        sky_axes, _, _ = gwutils.get_axes(wcsinfo)

        # Need to put the rotation matrix (List[float, float, float, float])
        # returned from calc_rotation_matrix into the correct shape for
        # constructing the transformation
        v3yangle = np.deg2rad(refmodel.meta.wcsinfo.v3yangle)
        vparity = refmodel.meta.wcsinfo.vparity
        if rotation is None:
            roll_ref = np.deg2rad(refmodel.meta.wcsinfo.roll_ref)
        else:
            # calc_rotation_matrix subtracts vparity * v3yangle again, so the
            # effective angle is exactly the requested rotation.
            roll_ref = np.deg2rad(rotation) + (vparity * v3yangle)

        pc = np.reshape(
            calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity),
            (2, 2)
        )

        pc_rotation = astmodels.AffineTransformation2D(pc, name='pc_rotation_matrix')
        transform.append(pc_rotation)

        if sky_axes:
            if not pscale:
                pscale = compute_scale(refmodel.meta.wcs, ref_fiducial,
                                       pscale_ratio=pscale_ratio)
            transform.append(astmodels.Scale(pscale, name='cdelt1') & astmodels.Scale(pscale, name='cdelt2'))

        if transform:
            transform = functools.reduce(lambda x, y: x | y, transform)

    out_frame = refmodel.meta.wcs.output_frame
    input_frame = refmodel.meta.wcs.input_frame
    wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj,
                             transform=transform, input_frame=input_frame)

    # Project every input footprint into the new pixel grid and shift so the
    # minimum corner sits at pixel 0.
    footprints = [w.footprint().T for w in wcslist]
    domain_bounds = np.hstack([wnew.backward_transform(*f) for f in footprints])
    axis_min_values = np.min(domain_bounds, axis=1)
    domain_bounds = (domain_bounds.T - axis_min_values).T

    output_bounding_box = []
    for axis in out_frame.axes_order:
        axis_min, axis_max = domain_bounds[axis].min(), domain_bounds[axis].max()
        output_bounding_box.append((axis_min, axis_max))

    output_bounding_box = tuple(output_bounding_box)
    if crpix is None:
        offset1, offset2 = wnew.backward_transform(*fiducial)
        offset1 -= axis_min_values[0]
        offset2 -= axis_min_values[1]
    else:
        offset1, offset2 = crpix
    offsets = astmodels.Shift(-offset1, name='crpix1') & astmodels.Shift(-offset2, name='crpix2')

    wnew.insert_transform('detector', offsets, after=True)
    wnew.bounding_box = output_bounding_box

    if shape is None:
        shape = [int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]]

    wnew.pixel_shape = shape[::-1]
    wnew.array_shape = shape
    return wnew
def compute_fiducial(wcslist, bounding_box=None):
    """
    Compute a fiducial point from the combined footprints of several WCSs.

    For a celestial footprint this is the center.
    For a spectral footprint, it is the beginning of the range.
    This function assumes all WCSs have the same output coordinate frame.

    Parameters
    ----------
    wcslist : list of `~gwcs.wcs.WCS`
        Non-empty list of WCS objects; the axis types of the first one
        define the layout of the result.
    bounding_box : tuple, optional
        Passed to ``footprint`` of every WCS.

    Returns
    -------
    fiducial : `~numpy.ndarray`
        One value per output axis. Spatial axes hold the (lon, lat) center,
        in degrees, of the midpoint of the footprints' Cartesian unit-vector
        extent (lon wrapped to [0, 360)). Spectral axes hold the smallest
        spectral value. Axes of any other type are 0.
    """
    axes_types = np.array(wcslist[0].output_frame.axes_type)
    spatial_axes = axes_types == 'SPATIAL'
    spectral_axes = axes_types == 'SPECTRAL'
    footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
    spatial_footprint = footprints[spatial_axes]
    spectral_footprint = footprints[spectral_axes]

    # Zero-initialise so axes that are neither spatial nor spectral get a
    # defined value instead of uninitialised memory.
    fiducial = np.zeros(len(axes_types))

    # Test for the presence of the axes (size), not for non-zero values:
    # a footprint at exactly lon=0, lat=0 is a valid position.
    if spatial_footprint.size:
        lon, lat = np.deg2rad(spatial_footprint)
        # Work with unit vectors so the center is robust to the RA wrap.
        x = np.cos(lat) * np.cos(lon)
        y = np.cos(lat) * np.sin(lon)
        z = np.sin(lat)

        x_mid = (np.max(x) + np.min(x)) / 2.
        y_mid = (np.max(y) + np.min(y)) / 2.
        z_mid = (np.max(z) + np.min(z)) / 2.

        lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0
        lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid ** 2 + y_mid ** 2)))
        fiducial[spatial_axes] = lon_fiducial, lat_fiducial

    if spectral_footprint.size:
        fiducial[spectral_axes] = spectral_footprint.min()
    return fiducial
def is_fits(input_img):
    """
    Determine whether the input is a FITS file and, if so, its layout.

    Parameters
    ----------
    input_img : str or `~astropy.io.fits.HDUList`
        A file name, or an already opened HDU list.

    Returns
    --------
    isFits: tuple
        An ``(isfits, fitstype)`` tuple. The values of ``isfits`` and
        ``fitstype`` are specified as:

         - ``isfits``: True|False
         - ``fitstype``: if True, one of 'waiver', 'mef', 'simple'; if False,
           None. It is also None when the primary HDU has data and the
           second HDU exists but is not a ``TableHDU``.

    Notes
    -----
    Input images which do not have a valid FITS filename will automatically
    result in a return of (False, None).

    In the case that the input has a valid FITS filename but runs into some
    error upon opening, this routine will raise that exception for the calling
    routine/user to handle. A file opened here is always closed again, even
    if inspecting it fails; an ``HDUList`` passed in is never closed.
    """
    fitstype = None
    suffixes = ('fits', 'fit', 'FITS', 'FIT')

    # determine if input is a fits file based on extension
    # Only check type of FITS file if filename ends in valid FITS string
    if isinstance(input_img, fits.HDUList):
        isfits = True
        f = input_img
    else:
        isfits = input_img.endswith(suffixes)
        f = None

    if not isfits:
        return isfits, fitstype

    # Compare against None: an empty HDUList is falsy but already "open".
    fileclose = f is None
    if fileclose:
        f = fits.open(input_img, mode='readonly')

    try:
        # if input is a fits file determine what kind of fits it is
        # waiver fits len(shape) == 3
        data0 = f[0].data
        if data0 is not None:
            try:
                if isinstance(f[1], fits.TableHDU):
                    fitstype = 'waiver'
            except IndexError:
                fitstype = 'simple'
        else:
            fitstype = 'mef'
    finally:
        if fileclose:
            f.close()

    return isfits, fitstype
def subarray_transform(input_model):
    """
    Build the subarray-to-full-frame offset model, if one is needed.

    Parameters
    ----------
    input_model : `~jwst.datamodels.JwstDataModel`
        Data model; ``meta.subarray.xstart`` and ``meta.subarray.ystart``
        (1-based) are read.

    Returns
    -------
    subarray2full : `~astropy.modeling.core.Model` or ``None``
        ``x_model & y_model`` where each axis model is ``Shift(start - 1)``
        when that start is set and differs from 1, and ``Identity(1)``
        otherwise. ``None`` for a full frame observation (no shift on
        either axis).
    """
    # Both start values are 1-based pixel indices.
    xstart = input_model.meta.subarray.xstart
    ystart = input_model.meta.subarray.ystart
    shift_x = xstart is not None and xstart != 1
    shift_y = ystart is not None and ystart != 1

    if not shift_x and not shift_y:
        # the case of a full frame observation
        return None

    x_model = astmodels.Shift(xstart - 1) if shift_x else astmodels.Identity(1)
    y_model = astmodels.Shift(ystart - 1) if shift_y else astmodels.Identity(1)
    return x_model & y_model
def not_implemented_mode(input_model, ref, slit_y_range=None):
    """
    Report that assign_wcs is not implemented for this mode.

    Parameters
    ----------
    input_model : data model
        Only ``meta.exposure.type`` is read, for the log message.
    ref : object
        Unused here (presumably kept so the signature matches the other
        per-mode functions - confirm).
    slit_y_range : optional
        Unused here.

    Returns
    -------
    None
        Always; a critical message is logged first.
    """
    log.critical(f"WCS for EXP_TYPE of {input_model.meta.exposure.type} is not implemented.")
    return None
def get_object_info(catalog_name=None):
    """Return a list of SkyObjects from the direct image

    The source_catalog step catalog items are read into a list
    of SkyObjects which can be referenced by catalog id. Only
    the columns needed by the WFSS code are saved.

    Parameters
    ----------
    catalog_name : str, astropy.table.table.Qtable
        The name of the photutils catalog (read as ``ascii.ecsv``) or its
        quantities table.

    Returns
    -------
    objects : list[SkyObject]
        A list of ``SkyObject`` tuples (from ``jwst.lib.catalog_utils``),
        one per catalog row.

    Raises
    ------
    ValueError
        If ``catalog_name`` is an empty string.
    FileNotFoundError
        If the catalog file does not exist.
    TypeError
        If ``catalog_name`` is neither a string nor a ``QTable``.
    AttributeError
        If the catalog has no ``colnames``.
    KeyError
        If required columns are missing from the catalog.
    """
    if isinstance(catalog_name, str):
        if len(catalog_name) == 0:
            err_text = "Empty catalog filename"
            log.error(err_text)
            raise ValueError(err_text)
        try:
            catalog = QTable.read(catalog_name, format='ascii.ecsv')
        except FileNotFoundError as e:
            log.error("Could not find catalog file: {0}".format(e))
            raise FileNotFoundError("Could not find catalog: {0}".format(e)) from e
    elif isinstance(catalog_name, QTable):
        catalog = catalog_name
    else:
        err_text = "Need to input string name of catalog or astropy.table.table.QTable instance"
        log.error(err_text)
        raise TypeError(err_text)

    # validate that the expected columns are there
    required_fields = set(SkyObject()._fields)
    try:
        colnames = set(catalog.colnames)
    except AttributeError as e:
        err_text = "Problem validating object catalog columns: {0}".format(e)
        log.error(err_text)
        raise AttributeError(err_text) from e

    missing = required_fields - colnames
    if missing:
        err_text = "Missing required columns in source catalog: {0}".format(missing)
        log.error(err_text)
        raise KeyError(err_text)

    # The columns are named sky_bbox_ll, sky_bbox_ul, sky_bbox_lr,
    # and sky_bbox_ur, each of which is a SkyCoord (i.e. RA & Dec & frame) at
    # one corner of the minimal bounding box. There will also be a sky_bbox
    # property as a 4-tuple of SkyCoord, but that is not serializable
    # (hence, the four separate columns).
    return [
        SkyObject(label=row['label'],
                  xcentroid=row['xcentroid'],
                  ycentroid=row['ycentroid'],
                  sky_centroid=row['sky_centroid'],
                  isophotal_abmag=row['isophotal_abmag'],
                  isophotal_abmag_err=row['isophotal_abmag_err'],
                  sky_bbox_ll=row['sky_bbox_ll'],
                  sky_bbox_lr=row['sky_bbox_lr'],
                  sky_bbox_ul=row['sky_bbox_ul'],
                  sky_bbox_ur=row['sky_bbox_ur'],
                  is_extended=row['is_extended'])
        for row in catalog
    ]
def create_grism_bbox(input_model,
                      reference_files=None,
                      mmag_extract=None,
                      extract_orders=None,
                      wfss_extract_half_height=None,
                      wavelength_range=None,
                      nbright=None):
    """Create bounding boxes for each object in the catalog

    The sky coordinates in the catalog image are first related
    to the grism image. They need to go through the WCS object
    in order to find the "direct image" pixel location, which is
    also in a detector pixel coordinate frame. This "direct image"
    location can then be sent through the trace polynomials to find
    the spectral location on the grism image for that wavelength and order.

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model which holds the grism image
    reference_files : dict, optional
        Dictionary of reference file names. When given, the
        ``'wavelengthrange'`` entry supplies the wavelength ranges and
        replaces any ``wavelength_range`` argument.
        If ``None``, ``wavelength_range`` must be supplied to specify
        the orders and corresponding wavelength ranges to be used in extraction.
    mmag_extract : float, optional
        The faintest magnitude to extract from the catalog. Must be an int
        or float when given; ``None`` keeps every object (limit 999).
    extract_orders : list, optional
        The list of orders to extract, if specified this will
        override the orders listed in the wavelengthrange reference file.
        If ``None``, the default one in the wavelengthrange reference file is used.
        Only used together with ``reference_files``.
    wfss_extract_half_height : int, optional
        Cross-dispersion extraction half height in pixels, WFSS mode.
        Overwrites the computed extraction height in ``GrismObject.order_bounding.``
        If ``None``, it's computed from the segmentation map,
        using the min and max wavelength for each of the orders that
        are available.
    wavelength_range : dict, optional
        Pairs of {spectral_order: (wave_min, wave_max)} for each order.
        Required when ``reference_files`` is ``None``.
    nbright : int, optional
        The number of brightest objects to extract from the catalog.

    Returns
    -------
    grism_objects : list
        A list of GrismObject(s) for every source in the catalog.
        Each grism object contains information about its
        spectral extent.

    Raises
    ------
    ValueError
        For instruments other than NIRCAM/NIRISS, a non-WFSS
        wavelengthrange reference file, or a missing source catalog.
    TypeError
        If neither ``reference_files`` nor ``wavelength_range`` is given,
        or ``mmag_extract`` is not a number.

    Notes
    -----
    The wavelengthrange reference file is used to govern
    the extent of the bounding box for each object. The name of the
    catalog has been stored in the input models meta information under
    the source_catalog key.

    It's left to the calling routine to cut the bounding boxes at the
    extent of the detector (for example, extract 2d would only extract
    the on-detector portion of the bounding box)

    Bounding box dispersion direction is dependent on the filter and
    module for NIRCAM and changes for GRISMR, but is consistent for GRISMC,
    see https://jwst-docs.stsci.edu/display/JTI/NIRCam+Wide+Field+Slitless+Spectroscopy

    NIRISS has one detector. GRISMC disperses along rows and
    GRISMR disperses along columns.

    If ``wfss_extract_half_height`` is specified it is used to compute the extent in
    the cross-dispersion direction, which becomes ``2 * wfss_extract_half_height + 1``.
    ``wfss_extract_half_height`` can only be applied to point source objects.
    """
    # The name used to look up orders comes from meta.instrument.filter for
    # NIRCAM and from meta.instrument.pupil for NIRISS.
    filter_attribute = {"NIRCAM": "filter", "NIRISS": "pupil"}
    instr_name = input_model.meta.instrument.name
    if instr_name not in filter_attribute:
        raise ValueError("create_grism_object works with NIRCAM and NIRISS WFSS exposures only.")
    filter_name = getattr(input_model.meta.instrument, filter_attribute[instr_name])

    if reference_files is not None:
        # Orders and (lmin, lmax) come from the ``wavelengthrange`` reference file.
        with WavelengthrangeModel(reference_files['wavelengthrange']) as wr_model:
            if 'WFSS' not in wr_model.meta.exposure.type:
                err_text = "Wavelengthrange reference file not for WFSS"
                log.error(err_text)
                raise ValueError(err_text)
            ref_extract_orders = wr_model.extract_orders
            if extract_orders is None:
                # Use the last reference entry whose first element matches.
                matching = [entry[1] for entry in ref_extract_orders if entry[0] == filter_name]
                extract_orders = matching.pop()
            wavelength_range = wr_model.get_wfss_wavelength_range(filter_name, extract_orders)
    elif wavelength_range is None:
        raise TypeError("If reference files are not supplied, ``wavelength_range`` must be provided.")

    if mmag_extract is None:
        mmag_extract = 999.  # extract all objects, regardless of magnitude
    else:
        log.info(f"Extracting objects < abmag = {mmag_extract}")
        if not isinstance(mmag_extract, (int, float)):
            raise TypeError(f"Expected mmag_extract to be a number, got {mmag_extract}")

    # extract the catalog objects
    source_catalog = input_model.meta.source_catalog
    if source_catalog is None:
        err_text = "No source catalog listed in datamodel."
        log.error(err_text)
        raise ValueError(err_text)

    log.info(f"Getting objects from {source_catalog}")

    return _create_grism_bbox(input_model, mmag_extract, wfss_extract_half_height,
                              wavelength_range, nbright)
def _create_grism_bbox(input_model, mmag_extract=None, wfss_extract_half_height=None,
                       wavelength_range=None, nbright=None):
    """Build GrismObjects with per-order bounding boxes for catalog sources.

    Worker for `create_grism_bbox`, which validates the inputs.

    Parameters
    ----------
    input_model : data model
        Grism exposure. Reads ``meta.source_catalog``, ``meta.wcs``,
        ``meta.wcsinfo.dispersion_direction`` and
        ``meta.subarray.xsize``/``ysize``.
    mmag_extract : float, optional
        Only objects with ``isophotal_abmag < mmag_extract`` are used.
        ``None`` applies no magnitude limit.
    wfss_extract_half_height : int, optional
        When given, point sources (``is_extended`` false) get a fixed
        cross-dispersion extent of ``center +/- wfss_extract_half_height``.
    wavelength_range : dict
        ``{order: (lmin, lmax)}`` for every order to process.
    nbright : int, optional
        If given, keep only the ``nbright`` objects with the smallest
        ``isophotal_abmag``.

    Returns
    -------
    list of GrismObject
        Objects with at least one order whose bounding box overlaps the
        subarray.

    Raises
    ------
    ValueError
        If ``wfss_extract_half_height`` is used and the dispersion
        direction is neither 1 nor 2.
    """
    log.debug(f'Extracting with wavelength_range {wavelength_range}')
    if mmag_extract is None:
        # No magnitude limit (a None limit would fail in the comparison below).
        mmag_extract = np.inf

    # this contains the pure information from the catalog with no translations
    skyobject_list = get_object_info(input_model.meta.source_catalog)

    # get the imaging transform to record the center of the object in the image
    # here, image is in the imaging reference frame, before going through the
    # dispersion coefficients
    sky_to_detector = input_model.meta.wcs.get_transform('world', 'detector')
    sky_to_grism = input_model.meta.wcs.backward_transform

    # Columns of ``pts`` (rows are [y, x]) along and across the dispersion.
    dispersion_direction = input_model.meta.wcsinfo.dispersion_direction
    if dispersion_direction == 1:
        # X-axis is dispersion direction
        disp_col, xdisp_col = 1, 0
    else:
        # Y-axis is dispersion direction
        disp_col, xdisp_col = 0, 1

    grism_objects = []  # the return list of GrismObjects
    for obj in skyobject_list:
        # Skip objects without a magnitude or not brighter than the limit.
        if obj.isophotal_abmag is None or not (obj.isophotal_abmag < mmag_extract):
            continue

        # could add logic to ignore object if too far off image,

        # save the image frame center of the object
        # takes in ra, dec, wavelength, order but wave and order
        # don't get used until the detector->grism_detector transform
        xcenter, ycenter, _, _ = sky_to_detector(obj.sky_centroid.icrs.ra.value,
                                                 obj.sky_centroid.icrs.dec.value,
                                                 1, 1)

        # Sky corners of the object's bounding box; identical for every order.
        ra = np.array([obj.sky_bbox_ll.ra.value, obj.sky_bbox_lr.ra.value,
                       obj.sky_bbox_ul.ra.value, obj.sky_bbox_ur.ra.value])
        dec = np.array([obj.sky_bbox_ll.dec.value, obj.sky_bbox_lr.dec.value,
                        obj.sky_bbox_ul.dec.value, obj.sky_bbox_ur.dec.value])

        order_bounding = {}
        waverange = {}
        partial_order = {}
        for order in wavelength_range:
            # The orders of the bounding box in the non-dispersed image
            # drive the extraction extent. The location of the min and
            # max wavelengths for each order are used to get the
            # location of the +/- sides of the bounding box in the
            # grism image
            lmin, lmax = wavelength_range[order]
            x1, y1, _, _, _ = sky_to_grism(ra, dec, [lmin] * 4, [order] * 4)
            x2, y2, _, _, _ = sky_to_grism(ra, dec, [lmax] * 4, [order] * 4)

            xstack = np.hstack([x1, x2])
            ystack = np.hstack([y1, y2])

            # Subarrays are only allowed in nircam tsgrism mode. The polynomial transforms
            # only work with the full frame coordinates. The code here is called during extract_2d,
            # and is creating bounding boxes which should be in the full frame coordinates, it just
            # uses the input catalog and the magnitude to limit the objects that need bounding boxes.

            # Tsgrism is always supposed to have the source object at the same pixel, and that is
            # hardcoded into the transforms. At least a while ago, the 2d extraction for tsgrism mode
            # didn't call this bounding box code. So I think it's safe to leave the subarray
            # subtraction out, i.e. do not subtract x/ystart.
            xmin = np.nanmin(xstack)
            xmax = np.nanmax(xstack)
            ymin = np.nanmin(ystack)
            ymax = np.nanmax(ystack)

            if wfss_extract_half_height is not None and not obj.is_extended:
                if dispersion_direction == 2:
                    ra_center, dec_center = obj.sky_centroid.ra.value, obj.sky_centroid.dec.value
                    center, _, _, _, _ = sky_to_grism(ra_center, dec_center, (lmin + lmax) / 2, order)
                    xmin = center - wfss_extract_half_height
                    xmax = center + wfss_extract_half_height
                elif dispersion_direction == 1:
                    ra_center, dec_center = obj.sky_centroid.ra.value, obj.sky_centroid.dec.value
                    _, center, _, _, _ = sky_to_grism(ra_center, dec_center, (lmin + lmax) / 2, order)
                    ymin = center - wfss_extract_half_height
                    ymax = center + wfss_extract_half_height
                else:
                    raise ValueError("Cannot determine dispersion direction.")

            # Convert floating-point corner values to whole pixel indexes
            xmin = gwutils._toindex(xmin)
            xmax = gwutils._toindex(xmax)
            ymin = gwutils._toindex(ymin)
            ymax = gwutils._toindex(ymax)

            # Don't add objects and orders that are entirely off the detector.
            # "partial_order" marks objects that are near enough to the detector
            # edge to have some spectrum on the detector.
            # This is useful because the catalog often is created from a resampled direct
            # image that is bigger than the detector FOV for a single grism exposure.

            # Here we check to ensure that the extraction region `pts`
            # has at least two pixels of width in the dispersion
            # direction, and one in the cross-dispersed direction when
            # placed into the subarray extent.
            pts = np.array([[ymin, xmin], [ymax, xmax]])
            subarr_extent = np.array([[0, 0],
                                      [input_model.meta.subarray.ysize - 1,
                                       input_model.meta.subarray.xsize - 1]])

            dispaxis_check = ((pts[1, disp_col] - subarr_extent[0, disp_col] > 0) and
                              (subarr_extent[1, disp_col] - pts[0, disp_col] > 0))
            xdispaxis_check = ((pts[1, xdisp_col] - subarr_extent[0, xdisp_col] >= 0) and
                               (subarr_extent[1, xdisp_col] - pts[0, xdisp_col] >= 0))

            if not (dispaxis_check and xdispaxis_check):
                log.info("Excluding off-image object: {}, order {}".format(obj.label, order))
                continue

            # Partially on the detector when either corner lies outside the subarray.
            inidx = np.all(np.logical_and(subarr_extent[0] <= pts, pts <= subarr_extent[1]), axis=1)
            ispartial = not np.all(inidx)
            if ispartial:
                log.info("Partial order on detector for obj: {} order: {}".format(obj.label, order))

            order_bounding[order] = ((ymin, ymax), (xmin, xmax))
            waverange[order] = (lmin, lmax)
            partial_order[order] = ispartial

        if order_bounding:
            grism_objects.append(GrismObject(sid=obj.label,
                                             order_bounding=order_bounding,
                                             sky_centroid=obj.sky_centroid,
                                             partial_order=partial_order,
                                             waverange=waverange,
                                             sky_bbox_ll=obj.sky_bbox_ll,
                                             sky_bbox_lr=obj.sky_bbox_lr,
                                             sky_bbox_ul=obj.sky_bbox_ul,
                                             sky_bbox_ur=obj.sky_bbox_ur,
                                             xcentroid=xcenter,
                                             ycentroid=ycenter,
                                             is_extended=obj.is_extended,
                                             isophotal_abmag=obj.isophotal_abmag))

    # At this point we have a list of grism objects limited to
    # isophotal_abmag < mmag_extract. We now need to further restrict
    # the list to the N brightest objects, as given by nbright.
    if nbright is None:
        # Include all objects, regardless of brightness
        final_objects = grism_objects
    else:
        # Order by magnitude (smaller is brighter) and keep the first N.
        indxs = np.argsort([obj.isophotal_abmag for obj in grism_objects])
        final_objects = [grism_objects[i] for i in indxs[:nbright]]

    log.info(f"Total of {len(final_objects)} grism objects defined")
    if len(final_objects) == 0:
        log.warning("No grism objects saved; check catalog or step params")

    return final_objects
def get_num_msa_open_shutters(shutter_state):
    """
    Count the open shutters in a slitlet.

    Parameters
    ----------
    shutter_state : str
        ``Slit.shutter_state`` attribute - a combination of
        ``1`` - open shutter, ``0`` - closed shutter, ``x`` - main shutter.

    Returns
    -------
    int
        Number of ``'1'`` characters, plus one if ``'x'`` occurs anywhere in
        the string (at most one ``'x'`` is counted).
    """
    main_shutter = 1 if 'x' in shutter_state else 0
    return shutter_state.count('1') + main_shutter
def transform_bbox_from_shape(shape):
    """Create a bounding box from the shape of the data.

    This is appropriate to attach to a transform.

    Parameters
    ----------
    shape : tuple
        The shape attribute from a `numpy.ndarray` array; only the last two
        entries are used.

    Returns
    -------
    bbox : tuple
        Bounding box in y, x order, with pixel edges at ``-0.5`` and
        ``n - 0.5``.
    """
    ny = shape[-2]
    nx = shape[-1]
    return ((-0.5, ny - 0.5), (-0.5, nx - 0.5))
def wcs_bbox_from_shape(shape):
    """Create a bounding box from the shape of the data.

    This is appropriate to attach to a wcs object.

    Parameters
    ----------
    shape : tuple
        The shape attribute from a `numpy.ndarray` array; only the last two
        entries are used.

    Returns
    -------
    bbox : tuple
        Bounding box in x, y order, with pixel edges at ``-0.5`` and
        ``n - 0.5``.
    """
    ny = shape[-2]
    nx = shape[-1]
    return ((-0.5, nx - 0.5), (-0.5, ny - 0.5))
def bounding_box_from_subarray(input_model):
    """Create a bounding box from the subarray size.

    Only ``meta.subarray.xsize`` and ``meta.subarray.ysize`` are used; the
    subarray start position is ignored. The box is in 0-based pixel
    coordinates starting at pixel 0 and is set to
    ``((-0.5, ysize - 0.5), (-0.5, xsize - 0.5))``. If a size is ``None``,
    the corresponding interval collapses to ``(-0.5, -0.5)``.

    Parameters
    ----------
    input_model : `~jwst.datamodels.JwstDataModel`
        The data model.

    Returns
    -------
    bbox : tuple
        Bounding box in y, x order.
    """
    xsize = input_model.meta.subarray.xsize
    ysize = input_model.meta.subarray.ysize
    x_end = -0.5 if xsize is None else xsize - 0.5
    y_end = -0.5 if ysize is None else ysize - 0.5
    return ((-0.5, y_end), (-0.5, x_end))
def update_s_region_imaging(model):
    """
    Update the ``S_REGION`` keyword using ``WCS.footprint``.

    If the WCS has no bounding box, one is built from ``model.data.shape``
    and stored on the WCS before the footprint is evaluated. Negative RA
    values are shifted by +360 before the keyword is written.
    """
    wcs_obj = model.meta.wcs
    bbox = wcs_obj.bounding_box
    if bbox is None:
        bbox = wcs_bbox_from_shape(model.data.shape)
        wcs_obj.bounding_box = bbox

    # Spatial footprint, transposed so row 0 holds RA and row 1 holds Dec.
    corners = wcs_obj.footprint(bbox, center=True, axis_type="spatial").T[:2, :]

    ra = corners[0]
    below_zero = ra < 0
    if below_zero.any():
        ra[below_zero] = 360 + ra[below_zero]

    update_s_region_keyword(model, corners.T)
def compute_footprint_spectral(model):
    """
    Determine spatial footprint for spectral observations using the instrument model.

    The WCS is evaluated on the full pixel grid of its bounding box (or of
    ``model.data.shape`` if it has none).

    Parameters
    ----------
    model : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.

    Returns
    -------
    footprint : `~numpy.ndarray`
        (4, 2) array of (RA, Dec) corners of the RA/Dec min/max rectangle.
    spectral_region : tuple
        ``(min, max)`` of the wavelengths, ignoring NaNs.
    """
    wcs_obj = model.meta.wcs
    bbox = wcs_obj.bounding_box
    if bbox is None:
        bbox = wcs_bbox_from_shape(model.data.shape)

    x, y = grid_from_bounding_box(bbox)
    ra, dec, lam = wcs_obj(x, y)

    # Put all RA values on one side of the 0/360 border so that the
    # min and max below bracket the region correctly.
    ra = wrap_ra(ra)
    ra_lo = np.nanmin(ra)
    ra_hi = np.nanmax(ra)

    # Bring the extremes back into the 0 to 360 range for the footprint.
    if ra_lo < 0:
        ra_lo += 360.0
    if ra_hi >= 360.0:
        ra_hi -= 360.0

    dec_lo = np.nanmin(dec)
    dec_hi = np.nanmax(dec)
    footprint = np.array([[ra_lo, dec_lo],
                          [ra_hi, dec_lo],
                          [ra_hi, dec_hi],
                          [ra_lo, dec_hi]])
    return footprint, (np.nanmin(lam), np.nanmax(lam))
def update_s_region_spectral(model):
    """Update the S_REGION keyword and ``wcsinfo.spectral_region``.

    Both values come from `compute_footprint_spectral`.
    """
    result = compute_footprint_spectral(model)
    update_s_region_keyword(model, result[0])
    model.meta.wcsinfo.spectral_region = result[1]
def compute_footprint_nrs_slit(slit):
    """
    Compute the footprint of a Nirspec slit using the instrument model.

    The four corners of the slit, expressed in the ``slit_frame``, are pushed
    through the slit's ``slit_frame`` -> ``world`` transform.

    Parameters
    ----------
    slit : `~jwst.datamodels.SlitModel`
        Slit whose ``meta.wcs``, ``slit_ymin`` and ``slit_ymax`` are used.

    Returns
    -------
    footprint : ndarray
        ``(4, 2)`` array of (RA, Dec) corner coordinates.
    spectral_region : tuple
        ``(lam_min, lam_max)`` of the wavelengths returned by the transform.
    """
    to_world = slit.meta.wcs.get_transform("slit_frame", "world")

    # Corners of a virtual slit centered on (0, 0): x spans [-0.5, 0.5],
    # y spans the slit's own [slit_ymin, slit_ymax] range.
    y_low = slit.slit_ymin
    y_high = slit.slit_ymax
    corner_x = [-.5, -.5, .5, .5]
    corner_y = [y_low, y_high, y_high, y_low]

    # A fixed wavelength of 2 microns is used as input to the transform.
    corner_lam = [2e-6] * 4

    ra, dec, lam = to_world(corner_x, corner_y, corner_lam)
    footprint = np.array([ra, dec]).T
    return footprint, (np.nanmin(lam), np.nanmax(lam))
def update_s_region_nrs_slit(slit):
    """
    Update S_REGION and the spectral region of a NIRSpec slit.

    Parameters
    ----------
    slit : `~jwst.datamodels.SlitModel`
        Slit whose ``meta.wcsinfo.s_region`` and
        ``meta.wcsinfo.spectral_region`` are updated in place.
    """
    polygon, wave_range = compute_footprint_nrs_slit(slit)
    update_s_region_keyword(slit, polygon)
    slit.meta.wcsinfo.spectral_region = wave_range
def update_s_region_keyword(model, footprint):
    """
    Update the S_REGION keyword from a polygon footprint.

    Parameters
    ----------
    model : datamodel
        Model whose ``meta.wcsinfo.s_region`` is updated in place.
    footprint : array-like
        Polygon vertices as (RA, Dec) pairs in degrees, either as an
        ``(N, 2)`` array or as a flat sequence ``[ra0, dec0, ra1, dec1, ...]``.
        Every vertex is written to the keyword.

    Notes
    -----
    ``S_REGION`` is left unchanged (and a message is logged) when any
    footprint value is not finite (NaN or +/-inf).
    """
    values = np.asarray(footprint, dtype=float).ravel()

    # Check the numbers themselves rather than searching the formatted text
    # for "nan": this also rejects infinities.
    if not np.isfinite(values).all():
        # do not update s_region if there are NaNs.
        log.info("There are NaNs in s_region, S_REGION not updated.")
        return

    # Note the double space after "ICRS": each value carries its own leading
    # space, matching the historical format of this keyword.
    s_region = "POLYGON ICRS " + "".join(f" {value:.9f}" for value in values)
    model.meta.wcsinfo.s_region = s_region
    log.info("Update S_REGION to %s", model.meta.wcsinfo.s_region)
def compute_footprint_nrs_ifu(dmodel, mod):
    """
    Determine NIRSPEC IFU footprint using the instrument model.

    For efficiency this function uses the transforms directly,
    instead of the WCS object. The common transforms in the WCS
    model chain are referenced and reused; only the slice specific
    transforms are computed.

    If the transforms change this function should be revised.

    Parameters
    ----------
    dmodel : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.
    mod : module
        The imported ``nirspec`` module.

    Returns
    -------
    footprint : ndarray
        The spatial footprint as a flat array of 8 values:
        ``[ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max]``.
    spectral_region : tuple
        The wavelength range ``(lam_min, lam_max)`` for the observation.
    """
    # World coordinates of every slice are accumulated here, then reduced
    # to min/max values after the loop.
    ra_total = []
    dec_total = []
    lam_total = []
    # Only the wavelength range is used below; the first returned value is discarded.
    _, wrange = mod.spectral_order_wrange_from_model(dmodel)
    pipe = dmodel.meta.wcs.pipeline

    # Get the GWA to slit_frame transform
    g2s = pipe[2].transform

    # Construct a list of the transforms between coordinate frames.
    # Set a place holder ``Identity`` transform at index 2 and 3.
    # Update them with slice specific transforms.
    transforms = [pipe[0].transform]
    # NOTE(review): ``[1:]`` drops the first submodel of this compound
    # transform; presumably that part is not needed here — confirm.
    transforms.append(pipe[1].transform[1:])
    transforms.append(astmodels.Identity(1))
    transforms.append(astmodels.Identity(1))
    # The final pipeline step is excluded (presumably the output frame,
    # which carries no transform).
    transforms.extend([step.transform for step in pipe[4:-1]])

    # NOTE: the slice count (30) is hard-coded.
    for sl in range(30):
        transforms[2] = g2s.get_model(sl)
        # Create the full transform from ``slit_frame`` to ``detector``.
        # It is used to compute the bounding box.
        m = functools.reduce(lambda x, y: x | y, [tr.inverse for tr in transforms[:3][::-1]])
        bbox = mod.compute_bounding_box(m, wrange)
        # Insert the slice-specific transform at index 3, taken from pipe[3].
        # The extra ``Identity(1)`` handles the third input (presumably the
        # wavelength). ``mforw`` then chains all transforms up to ``world``.
        transforms[3] = pipe[3].transform.get_model(sl) & astmodels.Identity(1)
        mforw = functools.reduce(lambda x, y: x | y, transforms)
        # Evaluate the forward chain on every pixel of this slice's bounding box.
        x1, y1 = grid_from_bounding_box(bbox)
        ra, dec, lam = mforw(x1, y1)
        ra_total.extend(np.ravel(ra))
        dec_total.extend(np.ravel(dec))
        lam_total.extend(np.ravel(lam))
    # the wrapped ra values are forced to be on one side of ra-border
    # the wrapped ra are used to determine the correct min and max ra
    ra_total = wrap_ra(ra_total)
    ra_max = np.nanmax(ra_total)
    ra_min = np.nanmin(ra_total)
    # for the footprint we want ra to be between 0 to 360
    if ra_min < 0:
        ra_min = ra_min + 360.0
    if ra_max >= 360.0:
        ra_max = ra_max - 360.0

    dec_max = np.nanmax(dec_total)
    dec_min = np.nanmin(dec_total)
    lam_max = np.nanmax(lam_total)
    lam_min = np.nanmin(lam_total)
    # Flat corner list; update_s_region_keyword flattens its input anyway.
    footprint = np.array([ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max])
    return footprint, (lam_min, lam_max)
def update_s_region_nrs_ifu(output_model, mod):
    """
    Set S_REGION and the spectral region for a NRS_IFU exposure.

    Parameters
    ----------
    output_model : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs; its ``meta.wcsinfo`` is updated in place.
    mod : module
        The imported ``nirspec`` module.
    """
    polygon, wave_range = compute_footprint_nrs_ifu(output_model, mod)
    update_s_region_keyword(output_model, polygon)
    output_model.meta.wcsinfo.spectral_region = wave_range
def update_s_region_mrs(output_model):
    """
    Set S_REGION and the spectral region for a MIRI_MRS exposure.

    MRS uses the same WCS-based footprint as the generic spectral path, so
    this simply delegates to ``update_s_region_spectral``.

    Parameters
    ----------
    output_model : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs; its ``meta.wcsinfo`` is updated in place.
    """
    update_s_region_spectral(output_model)
def velocity_correction(velosys):
    """
    Compute wavelength correction to Barycentric reference frame.

    Parameters
    ----------
    velosys : float
        Radial velocity wrt Barycenter [m / s].

    Returns
    -------
    model : `~astropy.modeling.Model`
        ``Identity(1) * Const1D(factor)`` where
        ``factor = 1 / (1 + velosys / c)``. Its inverse divides by the
        same factor.
    """
    factor = (1 / (1 + velosys / c.value))
    forward = astmodels.Identity(1) * astmodels.Const1D(factor, name="velocity_correction")
    backward = astmodels.Identity(1) / astmodels.Const1D(factor, name="inv_vel_correction")
    forward.inverse = backward
    return forward
def wrap_ra(ravalues):
    """
    Put RA values on one side of the 0/360 border.

    Values on both sides of the 0/360 border make it difficult to determine
    the RA range of a region on the sky. The median of the finite values is
    used as a reference. Any value more than 180 away from it is shifted by
    360 toward the median's side: down by 360 when the median is below 180,
    up by 360 when it is above 180.

    Parameters
    ----------
    ravalues : list or numpy.ndarray
        Input RA values in degrees.

    Returns
    -------
    numpy.ndarray
        1-D array containing only the finite input values, after wrapping.
        Non-finite values (NaN/inf) are dropped, and multi-dimensional
        input is flattened. An ndarray is returned even for list input.
    """
    ra_in = np.asarray(ravalues)
    # Boolean-mask indexing returns a new flattened copy of the finite
    # values, so the caller's data is never modified.
    ra_wrap = ra_in[np.isfinite(ra_in)]
    if ra_wrap.size == 0:
        # Nothing to wrap; also avoids nanmedian's empty-slice warning.
        return ra_wrap

    median_ra = np.nanmedian(ra_wrap)
    # Values far from the median sit on the other side of the 0/360 border.
    far = np.fabs(ra_wrap - median_ra) > 180.0
    if far.any():
        # Plain (not in-place) arithmetic keeps numpy's casting behavior
        # for integer input.
        if median_ra < 180:
            ra_wrap[far] = ra_wrap[far] - 360.0
        elif median_ra > 180:
            ra_wrap[far] = ra_wrap[far] + 360.0
    return ra_wrap
def in_ifu_slice(slice_wcs, ra, dec, lam):
    """
    Given RA, DEC and LAM, flag which positions fall on a slice.

    The world coordinates are transformed back to the ``slicer`` frame.
    Their slicer X coordinate is compared with the slicer X coordinate of
    the slit center (``slit_frame`` point (0, 0) at a wavelength of 2e-6).

    Parameters
    ----------
    slice_wcs : `~gwcs.WCS`
        Slice WCS object.
    ra, dec, lam : float, ndarray
        Physical Coordinates.

    Returns
    -------
    onslice_ind : bool or ndarray of bool
        True where the slicer X coordinate of the input position matches that
        of the slit center within an absolute tolerance of 5e-4
        (``numpy.isclose`` semantics).
    """
    slicer2world = slice_wcs.get_transform('slicer', 'world')
    # Only the slicer X coordinate is needed for the on-slice test.
    slx, _, _ = slicer2world.inverse(ra, dec, lam)

    # Compute the slice X coordinate using the center of the slit.
    SLX, _, _ = slice_wcs.get_transform('slit_frame', 'slicer')(0, 0, 2e-6)

    onslice_ind = np.isclose(slx, SLX, atol=5e-4)
    return onslice_ind
def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, npoints=32,
                        crpix=None, projection='TAN', imwcs=None, **kwargs):
    """
    Update ``datamodel.meta.wcsinfo`` based on a FITS WCS + SIP approximation
    of a GWCS object. By default, this function will approximate
    the datamodel's GWCS object stored in ``datamodel.meta.wcs`` but it can
    also approximate a user-supplied GWCS object when provided via
    the ``imwcs`` parameter.

    The default mode in using this attempts to achieve roughly 0.01 pixel
    accuracy over the entire image.

    This function uses the :py:meth:`~gwcs.wcs.WCS.to_fits_sip` to
    create FITS WCS representations of GWCS objects. Only most important
    :py:meth:`~gwcs.wcs.WCS.to_fits_sip` parameters are exposed here. Other
    arguments to :py:meth:`~gwcs.wcs.WCS.to_fits_sip` can be passed via
    ``kwargs`` - see "Other Parameters" section below.
    Please refer to the documentation of :py:meth:`~gwcs.wcs.WCS.to_fits_sip`
    for more details.

    .. warning::
        This function modifies input data model's ``datamodel.meta.wcsinfo``
        members.

    Parameters
    ----------
    datamodel : `ImageModel`
        The input data model for imaging or WFSS mode whose ``meta.wcsinfo``
        field should be updated from GWCS. By default, ``datamodel.meta.wcs``
        is used to compute FITS WCS + SIP approximation. When ``imwcs`` is
        not `None` then computed FITS WCS will be an approximation of the WCS
        provided through the ``imwcs`` parameter.

    max_pix_error : float, optional
        Maximum allowed error over the domain of the pixel array. This
        error is the equivalent pixel error that corresponds to the maximum
        error in the output coordinate resulting from the fit based on
        a nominal plate scale.

    degree : int, iterable, None, optional
        Degree of the SIP polynomial. Default value `None` indicates that
        degree values ``[1...5]`` (``range(1, _MAX_SIP_DEGREE)``) will be
        considered and the lowest degree that meets accuracy requirements set
        by ``max_pix_error`` will be returned. Alternatively, ``degree`` can be
        an iterable containing allowed values for the SIP polynomial degree.
        This option is similar to default `None` but it allows caller to
        restrict the range of allowed SIP degrees used for fitting.
        Finally, ``degree`` can be an integer indicating the exact SIP degree
        to be fit to the WCS transformation. In this case
        ``max_pixel_error`` is ignored.

    npoints : int, optional
        The number of points in each dimension to sample the bounding box
        for use in the SIP fit. Minimum number of points is 3.

    crpix : list of float, None, optional
        Coordinates (1-based) of the reference point for the new FITS WCS.
        When not provided, i.e., when set to `None` (default) the reference
        pixel already specified in ``wcsinfo`` will be re-used. If
        ``wcsinfo`` does not contain ``crpix`` information, then the
        reference pixel will be chosen near the center of the bounding box
        for axes corresponding to the celestial frame.

    projection : str, `~astropy.modeling.projections.Pix2SkyProjection`, optional
        Projection to be used for the created FITS WCS. It can be specified
        as a string of three characters specifying a FITS projection code
        from Table 13 in
        `Representations of World Coordinates in FITS \
        <https://doi.org/10.1051/0004-6361:20021326>`_
        (Paper I), Greisen, E. W., and Calabretta, M. R., A & A, 395,
        1061-1075, 2002. Alternatively, it can be an instance of one of the
        `astropy's Pix2Sky_* <https://docs.astropy.org/en/stable/modeling/\
        reference_api.html#module-astropy.modeling.projections>`_
        projection models inherited from
        :py:class:`~astropy.modeling.projections.Pix2SkyProjection`.
        This value is forwarded to :py:meth:`~gwcs.wcs.WCS.to_fits_sip`.

    imwcs : `gwcs.WCS`, None, optional
        Imaging GWCS object for WFSS mode whose FITS WCS approximation should
        be computed and stored in the ``datamodel.meta.wcsinfo`` field.
        When ``imwcs`` is `None` then WCS from ``datamodel.meta.wcs``
        will be used.

        .. warning::

            Used with WFSS modes only. For other modes, supplying a different
            WCS from ``datamodel.meta.wcs`` will result in the GWCS and
            FITS WCS descriptions to diverge.

    Other Parameters
    ----------------
    max_inv_pix_error : float, None, optional
        Maximum allowed inverse error over the domain of the pixel array
        in pixel units. With the default value of `None` no inverse
        is generated.

    inv_degree : int, iterable, None, optional
        Degree of the inverse SIP polynomial. Default value `None` indicates
        that degree values ``[1...5]`` (``range(1, _MAX_SIP_DEGREE)``) will be
        considered and the lowest degree that meets accuracy requirements set
        by ``max_inv_pix_error`` will be returned. Alternatively,
        ``inv_degree`` can be an iterable containing allowed values for the
        SIP polynomial degree. This option is similar to default `None` but it
        allows caller to restrict the range of allowed SIP degrees used for
        fitting. Finally, ``inv_degree`` can be an integer indicating the
        exact SIP degree to be fit to the WCS transformation. In this case
        ``max_inv_pixel_error`` is ignored.

    bounding_box : tuple, None, optional
        A pair of tuples, each consisting of two numbers
        Represents the range of pixel values in both dimensions
        ((xmin, xmax), (ymin, ymax))

    verbose : bool, optional
        Print progress of fits.

    Returns
    -------
    FITS header with all SIP WCS keywords (``NAXIS*`` keywords removed).

    Raises
    ------
    ValueError
        If the WCS is not at least 2D, an exception will be raised. If the
        specified accuracy (both forward and inverse, both rms and maximum)
        is not achieved an exception will be raised.

    Notes
    -----

    Use of this requires a judicious choice of required accuracies.
    Attempts to use higher degrees (~7 or higher) will typically fail due
    to floating point problems that arise with high powers.

    For more details, see :py:meth:`~gwcs.wcs.WCS.to_fits_sip`.

    """
    if crpix is None:
        crpix = [datamodel.meta.wcsinfo.crpix1, datamodel.meta.wcsinfo.crpix2]
    if None in crpix:
        # Incomplete reference pixel: let ``to_fits_sip`` choose one.
        crpix = None

    # For WFSS modes the imaging WCS is passed as an argument.
    # For imaging modes it is retrieved from the datamodel.
    if imwcs is None:
        imwcs = datamodel.meta.wcs

    # Local copy so popping the "other parameters" below is self-contained.
    kwargs = dict(kwargs)

    # override default values for "other parameters":
    max_inv_pix_error = kwargs.pop('max_inv_pix_error', None)
    inv_degree = kwargs.pop('inv_degree', None)
    if inv_degree is None:
        inv_degree = range(1, _MAX_SIP_DEGREE)

    # limit default 'degree' range to _MAX_SIP_DEGREE:
    if degree is None:
        degree = range(1, _MAX_SIP_DEGREE)

    hdr = imwcs.to_fits_sip(
        max_pix_error=max_pix_error,
        degree=degree,
        max_inv_pix_error=max_inv_pix_error,
        inv_degree=inv_degree,
        npoints=npoints,
        crpix=crpix,
        # Previously accepted but silently dropped; forward it as documented.
        projection=projection,
        **kwargs
    )

    # update meta.wcsinfo with FITS keywords except for naxis*
    del hdr['naxis*']

    # maintain convention of lowercase keys
    hdr_dict = {k.lower(): v for k, v in hdr.items()}

    # Remove keywords that a previous fit may have left behind, so that stale
    # CDELT/PC terms or higher-order SIP coefficients do not survive the update.
    rm_keys = ['naxis', 'cdelt1', 'cdelt2',
               'pc1_1', 'pc1_2', 'pc2_1', 'pc2_2',
               'a_order', 'b_order', 'ap_order', 'bp_order']
    rm_keys.extend(f"{s}_{i}_{j}" for i in range(10) for j in range(10)
                   for s in ['a', 'b', 'ap', 'bp'])

    for key in rm_keys:
        if key in datamodel.meta.wcsinfo.instance:
            del datamodel.meta.wcsinfo.instance[key]

    # update meta.wcs_info with fit keywords
    datamodel.meta.wcsinfo.instance.update(hdr_dict)

    return hdr
def wfss_imaging_wcs(wfss_model, imaging, bbox=None, **kwargs):
    """
    Add a FITS WCS approximation for imaging mode to WFSS headers.

    An imaging GWCS is built from the WFSS model's reference files and
    given a bounding box. The box is, in order of preference, the explicit
    ``bbox``, the subarray box (when the subarray does not start at pixel
    (1, 1)), or a box covering ``wfss_model.data``. Its FITS + SIP
    approximation is then written to ``wfss_model.meta.wcsinfo``.

    Parameters
    ----------
    wfss_model : `~ImageModel`
        Input WFSS model (NRC or NIS).
    imaging : func, callable
        The ``imaging`` function in the ``niriss`` or ``nircam`` modules.
    bbox : tuple or None
        The bounding box over which to approximate the distortion solution.
        Typically this is based on the shape of the direct image.
    **kwargs
        Extra keyword arguments forwarded to ``update_fits_wcsinfo``.
    """
    subarray = wfss_model.meta.subarray
    xstart = subarray.xstart
    ystart = subarray.ystart

    refs = get_wcs_reference_files(wfss_model)
    image_wcs = WCS(imaging(wfss_model, refs))

    if bbox is None:
        offset_subarray = (xstart is not None and ystart is not None
                           and (xstart != 1 or ystart != 1))
        if offset_subarray:
            bbox = bounding_box_from_subarray(wfss_model)
        else:
            bbox = wcs_bbox_from_shape(wfss_model.data.shape)
    image_wcs.bounding_box = bbox

    update_fits_wcsinfo(wfss_model, projection='TAN', imwcs=image_wcs,
                        bounding_box=None, **kwargs)
def get_wcs_reference_files(datamodel):
    """
    Retrieve names of WCS reference files for NIS_WFSS and NRC_WFSS modes.

    Parameters
    ----------
    datamodel : `~ImageModel`
        Input WFSS file (NRC or NIS).

    Returns
    -------
    dict
        Maps each ``AssignWcsStep.reference_file_types`` entry to its
        reference file name, or to `None` when the lookup returns ``'N/A'``.
    """
    from jwst.assign_wcs import AssignWcsStep

    step = AssignWcsStep()

    def _lookup(reftype):
        name = step.get_reference_file(datamodel, reftype)
        return None if name.strip() == 'N/A' else name

    return {reftype: _lookup(reftype) for reftype in AssignWcsStep.reference_file_types}