import inspect
from functools import wraps
from itertools import chain
import numpy as np
import astropy.nddata
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS
from ndcube.utils import wcs as wcs_utils
from ndcube.utils.exceptions import warn_user
# Public helpers exported by this module.
__all__ = [
    "get_crop_item_from_points", "propagate_rebin_uncertainties",
    "sanitize_crop_inputs", "sanitize_wcs",
]
def sanitize_wcs(func):
    """
    A wrapper for NDCube methods to sanitise the wcs argument.

    This decorator is only designed to be used on methods of NDCube.
    It finds the wcs argument, whether passed by keyword or positionally, and
    if it is `None`, sets it to ``self.wcs``.
    It then verifies that the object is a HighLevelWCS object or an
    ExtraCoords object and, for a WCS, that it has the same number of pixel
    dimensions as the dimensionality of the array.

    Raises
    ------
    TypeError
        If the wcs is neither a High Level WCS nor an ExtraCoords object.
    ValueError
        If a WCS is given whose number of pixel dimensions does not match
        ``self.data.ndim``.
    """
    # This needs to be here to prevent a circular import
    from ndcube.extra_coords.extra_coords import ExtraCoords  # noqa: PLC0415

    # The signature of func cannot change after decoration, so compute it once
    # here instead of on every call of the wrapped method.
    sig = inspect.signature(func)

    @wraps(func)
    def wcs_wrapper(*args, **kwargs):
        params = sig.bind(*args, **kwargs)
        wcs = params.arguments.get('wcs', None)
        self = params.arguments['self']
        if wcs is None:
            wcs = self.wcs
        # Check the type first: an arbitrary object may not have ``pixel_n_dim``,
        # which would otherwise surface as an AttributeError instead of this TypeError.
        if not isinstance(wcs, (BaseHighLevelWCS, ExtraCoords)):
            raise TypeError("wcs argument must be a High Level WCS or an ExtraCoords object.")
        if not isinstance(wcs, ExtraCoords) and wcs.pixel_n_dim != self.data.ndim:
            raise ValueError("The supplied WCS must have the same number of "
                             "pixel dimensions as the NDCube object. "
                             "If you specified `cube.extra_coords.wcs` "
                             "please just pass `cube.extra_coords`.")
        params.arguments['wcs'] = wcs
        return func(*params.args, **params.kwargs)

    return wcs_wrapper
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims, original_shape):
    """
    Find slice item that crops to minimum cube in array-space containing specified world points.

    Parameters
    ----------
    points : iterable of iterables
        Each iterable represents a point in real world space.
        Each element in a point gives the real world coordinate value of the point
        in high-level coordinate objects or quantities.
        (Must be consistently high or low level within and across points.)
        Objects must be in the order required by
        wcs.world_to_array_index/world_to_array_index_values.
        An element may be `None`, in which case the point places no constraint on
        the corresponding pixel axes. Points whose elements are all `None` are ignored.
    wcs : `~astropy.wcs.wcsapi.BaseHighLevelWCS`, `~astropy.wcs.wcsapi.BaseLowLevelWCS`
        The WCS to use to convert the world coordinates to array indices.
    crop_by_values : `bool`
        Denotes whether cropping is done using high-level objects or "values",
        i.e. low-level objects.
    keepdims : `bool`
        If `False`, the returned item drops length-1 dimensions; otherwise, the
        item keeps length-1 dimensions.
    original_shape : `tuple` of `int`
        The shape of the data cube before cropping.

    Returns
    -------
    item : `tuple` of `slice` or `int`
        The slice item for each axis of the cube which, when applied to the cube,
        will return the minimum cube in array-index-space that contains all the
        input world points. An axis cropped to a single element is given as an
        `int` when ``keepdims`` is `False`.

    Raises
    ------
    ValueError
        If all world points associated with an array axis lie outside that axis's
        extent, or if ``keepdims`` is `False` and the item would crop the cube to a
        single element.
    """
    high_level_wcs = HighLevelWCSWrapper(wcs) if isinstance(wcs, BaseLowLevelWCS) else wcs
    low_level_wcs = high_level_wcs.low_level_wcs
    n_pixel_dims = low_level_wcs.pixel_n_dim
    # One independent list per pixel axis, each collecting the pixel coordinates of
    # all points along that axis.
    # Recall that pixel axis ordering is reversed compared to array axis ordering.
    combined_points_pixel_idx = [[] for _ in range(n_pixel_dims)]
    # The pixel axes associated with each element of a point depend only on the WCS,
    # so compute them once rather than once per point.
    if crop_by_values:
        point_inputs_pixel_axes = tuple(
            wcs_utils.world_axis_to_pixel_axes(i, low_level_wcs.axis_correlation_matrix)
            for i in range(low_level_wcs.world_n_dim))
    else:
        point_inputs_pixel_axes = wcs_utils.pixel_indices_for_world_objects(high_level_wcs)
    # For each point compute the corresponding pixel coordinates.
    for point in points:
        # Indices of the elements of point which hold a coordinate (i.e. are not None).
        point_indices_with_inputs = [i for i, coord in enumerate(point) if coord is not None]
        if not point_indices_with_inputs:
            # An all-None point constrains no axis. Skip it, as slicing every
            # pixel axis out of the WCS below would otherwise fail.
            continue
        axes_with_input = set(chain.from_iterable(
            point_inputs_pixel_axes[i] for i in point_indices_with_inputs))
        # Sort explicitly so these line up, in order, with the pixel axes of the
        # (possibly sliced) WCS used to compute the pixel coordinates below.
        pixel_axes_with_input = sorted(axes_with_input)
        pixel_axes_without_input = np.array(sorted(set(range(n_pixel_dims)) - axes_with_input))
        # Slice out the axes that do not correspond to a coord
        # from the WCS and the input point.
        if len(pixel_axes_without_input) > 0:
            array_axes_without_input = wcs_utils.convert_between_array_and_pixel_axes(
                pixel_axes_without_input, n_pixel_dims)
            wcs_slice = np.array([slice(None)] * n_pixel_dims)
            wcs_slice[array_axes_without_input] = 0
            sliced_wcs = SlicedLowLevelWCS(low_level_wcs, slices=tuple(wcs_slice))
            sliced_point = np.array(point, dtype=object)[np.array(point_indices_with_inputs)]
        else:
            # Else, if all axes have at least one crop input, no need to slice the WCS.
            sliced_wcs, sliced_point = low_level_wcs, np.array(point, dtype=object)
        # Derive the pixel indices of the input point.
        # Use the to_pixel methods to preserve fractional indices for future rounding.
        if crop_by_values:
            point_pixel_indices = sliced_wcs.world_to_pixel_values(*sliced_point)
        else:
            point_pixel_indices = HighLevelWCSWrapper(sliced_wcs).world_to_pixel(*sliced_point)
        # A 1-D WCS returns a bare value rather than a tuple of per-axis values.
        if sliced_wcs.pixel_n_dim == 1:
            point_pixel_indices = (point_pixel_indices,)
        # Place the pixel coord of each axis into that axis's list.
        for axis, index in zip(pixel_axes_with_input, point_pixel_indices):
            combined_points_pixel_idx[axis].append(index)
    # Iterate through each array axis to determine the min and max pixel coords
    # and then convert to array indices. combined_points_pixel_idx holds the
    # pixel coords for each pixel axis, so it is reversed to iterate in array axis order.
    item = []
    messages = []
    result_is_scalar = True
    for array_axis, pixel_coords in enumerate(combined_points_pixel_idx[::-1]):
        if not pixel_coords:
            # No point constrains this axis, so keep it whole.
            result_is_scalar = False
            item.append(slice(None))
            continue
        # Calculate the index of the array element containing the pixel coordinate.
        # Integer pixel coordinates correspond to the pixel center, while integer
        # array indices correspond to the lower edge of the desired array element,
        # so a shift of 0.5 is required in the conversion.
        # The max idx conversion discards the right-ward array element if the
        # max pixel coord corresponds to a pixel edge.
        min_array_idx = int(np.floor(min(pixel_coords) + 0.5))
        max_array_idx = int(np.ceil(max(pixel_coords) - 0.5)) + 1
        # Raise error if indices all lie below or all lie above the array axis's extent.
        # Exception: min_array_idx == max_array_idx == 0 is allowed because max_array_idx
        # will be later changed to 1.
        if (min_array_idx < 0 and max_array_idx <= 0) or min_array_idx >= original_shape[array_axis]:
            raise ValueError(f"All world points associated with array axis {array_axis}"
                             " are outside the range of the NDCube being cropped.")
        # Negative indices mean "left of element 0" here, but count backwards from
        # the end in python slicing, so clip the min index to 0. The max index needs
        # no clipping: values above the extent are harmless in a slice, and the min
        # index has already been checked to lie below the upper extent.
        min_array_idx = max(min_array_idx, 0)
        # The min and max indices can only be equal if all pixel coords lie on the
        # same pixel edge. In that case keep the rightward element and warn later.
        if min_array_idx == max_array_idx:
            max_array_idx += 1
            if min_array_idx == 0:
                messages.append(f"All input points corresponding to array axis {array_axis} lie on "
                                "the lower boundary of array element 0 (the first element). "
                                "The cropped NDCube will only include array element 0.\n")
            else:
                messages.append(f"All input points corresponding to array axis {array_axis} lie on "
                                f"the boundary between array elements {min_array_idx - 1} and "
                                f"{min_array_idx}. The cropped NDCube will only include array "
                                f"element {min_array_idx}.\n")
        if max_array_idx - min_array_idx == 1 and not keepdims:
            item.append(min_array_idx)
        else:
            item.append(slice(min_array_idx, max_array_idx))
            result_is_scalar = False
    # Warn if all world values for any array axes correspond to a pixel edge.
    if messages:
        warn_user("".join(messages))
    # If item would result in a scalar cube, raise an error as this is not currently supported.
    if result_is_scalar:
        raise ValueError("Input points causes cube to be cropped to a single pixel. "
                         "This is not supported when keepdims=False.")
    return tuple(item)
def propagate_rebin_uncertainties(uncertainty, data, mask, operation, operation_ignores_mask=False,
                                  propagation_operation=None, correlation=0, **kwargs):
    """
    Default algorithm for uncertainty propagation in :meth:`~ndcube.NDCube.rebin`.

    First dimension of uncertainty, data and mask inputs represent the pixels
    in the bin being aggregated by the rebin process while the latter dimensions
    must have the same shape as the rebinned data. The operation input is the
    function used to aggregate elements in the first dimension, e.g. `numpy.sum`.

    None of the inputs are modified in place.

    Parameters
    ----------
    uncertainty : `astropy.nddata.NDUncertainty`
        Cannot be instance of `astropy.nddata.UnknownUncertainty`.
        The uncertainties associated with the data. The first dimension represents
        pixels in each bin being aggregated while trailing dimensions must have
        the same shape as the rebinned data.
    data : array-like
        The data associated with the above uncertainties.
        Must have same shape as above.
    mask : array-like of `bool`, `False` or `None`
        Indicates whether any uncertainty elements should be ignored in propagation.
        If True, corresponding uncertainty element is ignored. If False, it is used.
        Must have same shape as above (or be broadcastable to it).
        `None` or `False` means no element is masked.
    operation : function
        The function used to aggregate the data for which the uncertainties are being
        propagated here. Must accept an ``axis`` keyword argument.
    operation_ignores_mask : `bool`
        Determines whether masked values are used or excluded from calculation.
        Default is False causing masked data and uncertainty to be excluded.
    propagation_operation : function
        The operation which defines how the uncertainties are propagated.
        This can differ from operation, e.g. if operation is sum, then
        propagation_operation should be add.
    correlation : `int`
        Passed to `astropy.nddata.NDUncertainty.propagate`. See that method's docstring.
        Default=0.
    **kwargs
        Accepted for API compatibility; unused.

    Returns
    -------
    new_uncertainty : `astropy.nddata.NDUncertainty`
        The propagated uncertainty.
        Same shape as input uncertainty without its first dimension.

    Raises
    ------
    ValueError
        If ``propagation_operation`` is not given and cannot be inferred from ``operation``.
    """
    flat_axis = 0
    operation_is_mean = operation in {np.mean, np.nanmean}
    operation_is_nantype = operation in {np.nansum, np.nanmean, np.nanprod}
    # If propagation_operation kwarg not set manually, try to set it based on operation kwarg.
    if not propagation_operation:
        if operation in {np.sum, np.nansum, np.mean, np.nanmean}:
            propagation_operation = np.add
        elif operation in {np.prod, np.nanprod}:
            propagation_operation = np.multiply
        else:
            raise ValueError("propagation_operation not recognized.")
    n_pix_per_bin = data.shape[flat_axis]

    # Build a boolean mask of the elements to exclude, or None if nothing is excluded.
    # A copy is taken so the caller's mask is never modified.
    if operation_ignores_mask or mask is None or mask is False:
        mask = None
    else:
        mask = np.array(np.broadcast_to(mask, data.shape), dtype=bool)
    if operation_is_nantype:
        # NaN-ignoring operations exclude NaN data regardless of the mask,
        # so their uncertainties must be excluded as well.
        nan_mask = np.isnan(data)
        if nan_mask.any():
            mask = nan_mask if mask is None else mask | nan_mask

    if mask is None:
        filled_data = data
    else:
        # Masked uncertainties must not count towards the final uncertainty.
        # Zero them in a new uncertainty object rather than the caller's.
        zeroed_array = np.array(uncertainty.array, copy=True)
        zeroed_array[mask] = 0
        uncertainty = uncertainty.__class__(zeroed_array, unit=uncertainty.unit, copy=False)
        # Replace masked data by the identity of the propagation operation
        # (0 for np.add, 1 for np.multiply) so masked elements leave both the
        # cumulative data and the propagated uncertainty unchanged.
        # Operations without an identity fall back to 0.
        fill_value = getattr(propagation_operation, "identity", None)
        filled_data = np.where(mask, 0 if fill_value is None else fill_value, data)

    # Propagate uncertainties.
    # Note uncertainty must be associated with a parent nddata for some propagations.
    # NOTE: astropy's NDUncertainty holds only a weak reference to its parent, so
    # ``parent_nddata`` is kept bound to a local name to keep it alive.
    new_uncertainty = uncertainty[0]  # Define uncertainty for initial iteration step.
    cumul_data = filled_data[0]
    parent_nddata = astropy.nddata.NDData(cumul_data, uncertainty=new_uncertainty)
    new_uncertainty.parent_nddata = parent_nddata
    for i in range(1, n_pix_per_bin):
        mask_slice = False if mask is None else mask[i]
        # Aggregate along the bin axis only, so the cumulative data stays per-pixel.
        cumul_data = operation(filled_data[:i + 1], axis=flat_axis)
        data_slice = astropy.nddata.NDData(data=filled_data[i], mask=mask_slice,
                                           uncertainty=uncertainty[i])
        new_uncertainty = new_uncertainty.propagate(propagation_operation, data_slice,
                                                    cumul_data, correlation)
        parent_nddata = astropy.nddata.NDData(cumul_data, uncertainty=new_uncertainty)
        new_uncertainty.parent_nddata = parent_nddata

    # If aggregation operation is mean, the summed uncertainty must be scaled by
    # the number of unmasked pixels, n, in each bin: std / n, variance / n**2,
    # inverse variance * n**2.
    if operation_is_mean and propagation_operation is np.add:
        if mask is None:
            n_unmasked = n_pix_per_bin
        else:
            n_unmasked = np.clip(np.logical_not(mask).astype(int).sum(axis=flat_axis), 1, None)
        if isinstance(new_uncertainty, astropy.nddata.VarianceUncertainty):
            new_uncertainty.array /= n_unmasked ** 2
        elif isinstance(new_uncertainty, astropy.nddata.InverseVariance):
            new_uncertainty.array *= n_unmasked ** 2
        else:
            new_uncertainty.array /= n_unmasked
    return new_uncertainty