import logging
import warnings
from dataclasses import dataclass
import numpy as np
from scipy.optimize import minimize
from stdatamodels.jwst import datamodels
from jwst.assign_wcs import nirspec
log = logging.getLogger(__name__)
__all__ = ["PixelReplaceArrays", "PixelReplacement"]
[docs]
@dataclass
class PixelReplaceArrays:
"""
Container for data arrays and dispersion direction.
Algorithms operate on this dataclass rather than on a
`~stdatamodels.jwst.datamodels.JwstDataModel`.
This avoids the overhead of constructing intermediate DataModel objects,
which was slowing runtime for TSO data with thousands of integrations,
and provides a consistent interface for :meth:`PixelReplacement.mingrad`
and :meth:`PixelReplacement.fit_profile`.
Attributes
----------
data : ndarray
Science array.
dq : ndarray
Data quality array.
err : ndarray
Total error array.
var_poisson : ndarray or None
Poisson variance array.
var_rnoise : ndarray or None
Read-noise variance array.
var_flat : ndarray or None
Flat-field variance array.
dispersion_direction : int
Dispersion direction.
"""
data: np.ndarray
dq: np.ndarray
err: np.ndarray
var_poisson: np.ndarray | None
var_rnoise: np.ndarray | None
var_flat: np.ndarray | None
dispersion_direction: int
[docs]
class PixelReplacement:
"""
Main class for performing pixel replacement.
This class controls loading the input data model, selecting the
method for pixel replacement, and executing each step. This class
should provide modularization to allow for multiple options and possible
future reference files.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
Datamodel with bad pixels to replace. Updated in-place.
**pars : dict, optional
Optional parameters to modify how pixel replacement
will execute.
"""
# Shortcuts for DQ Flags
DO_NOT_USE = datamodels.dqflags.pixel["DO_NOT_USE"]
FLUX_ESTIMATED = datamodels.dqflags.pixel["FLUX_ESTIMATED"]
NON_SCIENCE = datamodels.dqflags.pixel["NON_SCIENCE"]
# Shortcuts for dispersion direction for ease of reading
HORIZONTAL = 1
VERTICAL = 2
LOG_SLICE = ["column", "row"]
def __init__(self, input_model, **pars):
self.input = input_model
self.pars = {}
self.pars.update(pars)
# Store algorithm options here.
self.algorithm_dict = {
"fit_profile": self.fit_profile,
"mingrad": self.mingrad,
}
# Choose algorithm from dict using input par.
try:
self.algorithm = self.algorithm_dict[self.pars["algorithm"]]
except KeyError as err:
log.critical(
f"Algorithm name {self.pars['algorithm']} provided does "
"not match an implemented algorithm!"
)
raise KeyError from err
@staticmethod
def _arrays_from_model(model):
"""Extract PixelReplaceArrays from DataModel, copying arrays.""" # numpydoc ignore: RT01
return PixelReplaceArrays(
data=model.data.copy(),
dq=model.dq.copy(),
err=model.err.copy(),
var_poisson=model.var_poisson.copy(),
var_rnoise=model.var_rnoise.copy(),
var_flat=model.var_flat.copy(),
dispersion_direction=model.meta.wcsinfo.dispersion_direction,
)
@staticmethod
def _model_from_arrays(arrays, model):
"""Write PixelReplaceArrays back into a DataModel in place.""" # numpydoc ignore: RT01
model.data = arrays.data
model.dq = arrays.dq
model.err = arrays.err
model.var_poisson = arrays.var_poisson
model.var_rnoise = arrays.var_rnoise
model.var_flat = arrays.var_flat
[docs]
def replace(self):
"""
Unpack model and apply pixel replacement algorithm.
Process the input `~stdatamodels.jwst.datamodels.JwstDataModel`,
unpack any model that holds
more than one 2D spectrum, then apply selected algorithm
to each 2D spectrum in input.
"""
# ImageModel inputs (MIR_LRS-FIXEDSLIT)
# or 2D SlitModel inputs (e.g. NRS_FIXEDSLIT in spec3)
if isinstance(self.input, datamodels.ImageModel) or (
isinstance(self.input, datamodels.SlitModel) and self.input.data.ndim == 2
):
arrays = self._arrays_from_model(self.input)
arrays = self.algorithm(arrays)
self._model_from_arrays(arrays, self.input)
n_replaced = np.count_nonzero(self.input.dq & self.FLUX_ESTIMATED)
log.info(f"Input model had {n_replaced} pixels replaced.")
elif isinstance(self.input, datamodels.IFUImageModel):
# Attempt to run pixel replacement on each throw of the IFU slicer
# individually.
xx, yy = np.indices(self.input.data.shape)
if self.input.meta.exposure.type == "MIR_MRS":
if self.pars["algorithm"] == "mingrad":
# mingrad method
arrays = self._arrays_from_model(self.input)
arrays = self.algorithm(arrays)
self._model_from_arrays(arrays, self.input)
else:
# fit_profile method
det2ab = self.input.meta.wcs.get_transform(
self.input.meta.wcs.available_frames[0], "alpha_beta"
)
_, beta_array, _ = det2ab(yy, xx)
unique_beta = np.unique(beta_array)
unique_beta = unique_beta[~np.isnan(unique_beta)]
for i, beta in enumerate(unique_beta):
# Define a mask that is True where this trace is located
trace_mask = beta_array == beta
arrays = self._arrays_from_model(self.input)
arrays.dq = np.where(
# When not in this trace, set NON_SCIENCE and DO_NOT_USE
~trace_mask,
arrays.dq | self.DO_NOT_USE | self.NON_SCIENCE,
arrays.dq,
)
arrays = self.algorithm(arrays)
self.input.data = np.where(trace_mask, arrays.data, self.input.data)
self.input.dq = np.where(trace_mask, arrays.dq, self.input.dq)
self.input.err = np.where(trace_mask, arrays.err, self.input.err)
self.input.var_poisson = np.where(
trace_mask, arrays.var_poisson, self.input.var_poisson
)
self.input.var_rnoise = np.where(
trace_mask, arrays.var_rnoise, self.input.var_rnoise
)
self.input.var_flat = np.where(
trace_mask, arrays.var_flat, self.input.var_flat
)
n_replaced = np.count_nonzero(arrays.dq & self.FLUX_ESTIMATED)
log.info(
f"Input MRS frame had {n_replaced} pixels replaced "
f"in IFU slice {i + 1}."
)
n_replaced = np.count_nonzero(self.input.dq & self.FLUX_ESTIMATED)
log.info(f"Input MRS frame had {n_replaced} total pixels replaced.")
else:
if self.pars["algorithm"] == "mingrad":
# mingrad method
arrays = self._arrays_from_model(self.input)
arrays = self.algorithm(arrays)
self._model_from_arrays(arrays, self.input)
else:
# fit_profile method - iterate over IFU slices
for i in range(30):
slice_wcs = nirspec.nrs_wcs_set_input(self.input, i)
det2slicer = slice_wcs.get_transform(
self.input.meta.wcs.available_frames[0], "slicer"
)
_, _, wave = det2slicer(yy, xx)
# Define a mask that is True where this trace is located
trace_mask = wave > 0
arrays = self._arrays_from_model(self.input)
arrays.dq = np.where(
# When not in this trace, set NON_SCIENCE and DO_NOT_USE
~trace_mask,
arrays.dq | self.DO_NOT_USE | self.NON_SCIENCE,
arrays.dq,
)
arrays = self.algorithm(arrays)
self.input.data = np.where(trace_mask, arrays.data, self.input.data)
self.input.dq = np.where(trace_mask, arrays.dq, self.input.dq)
self.input.err = np.where(trace_mask, arrays.err, self.input.err)
self.input.var_poisson = np.where(
trace_mask, arrays.var_poisson, self.input.var_poisson
)
self.input.var_rnoise = np.where(
trace_mask, arrays.var_rnoise, self.input.var_rnoise
)
self.input.var_flat = np.where(
trace_mask, arrays.var_flat, self.input.var_flat
)
n_replaced = np.count_nonzero(arrays.dq & self.FLUX_ESTIMATED)
log.info(
f"Input NRS_IFU frame had {n_replaced} pixels "
f"replaced in IFU slice {i + 1}."
)
n_replaced = np.count_nonzero(self.input.dq & self.FLUX_ESTIMATED)
log.info(f"Input NRS_IFU frame had {n_replaced} total pixels replaced.")
# MultiSlitModel inputs (WFSS, NRS_FIXEDSLIT, ?)
elif isinstance(self.input, datamodels.MultiSlitModel):
for i, _slit in enumerate(self.input.slits):
slit_model = datamodels.SlitModel(self.input.slits[i].instance)
arrays = self._arrays_from_model(slit_model)
slit_model.close()
arrays = self.algorithm(arrays)
n_replaced = np.count_nonzero(arrays.dq & self.FLUX_ESTIMATED)
log.info(f"Slit {i} had {n_replaced} pixels replaced.")
self._model_from_arrays(arrays, self.input.slits[i])
# CubeModel inputs are TSO (so far?); SlitModel may be NRS_BRIGHTOBJ,
# also requiring a re-packaging of the data into 2D inputs for the algorithm
elif isinstance(self.input, datamodels.CubeModel | datamodels.SlitModel):
dispaxis = self.input.meta.wcsinfo.dispersion_direction
for i in range(len(self.input.data)):
# Ensure variance arrays exist
var_dict = {
"var_poisson": None,
"var_rnoise": None,
"var_flat": None,
}
for key in var_dict.keys():
if self.input[key] is not None:
var_dict[key] = self.input[key][i].copy()
arrays = PixelReplaceArrays(
data=self.input.data[i].copy(),
dq=self.input.dq[i].copy(),
err=self.input.err[i].copy(),
var_poisson=var_dict["var_poisson"],
var_rnoise=var_dict["var_rnoise"],
var_flat=var_dict["var_flat"],
dispersion_direction=dispaxis,
)
arrays = self.algorithm(arrays)
n_replaced = np.count_nonzero(arrays.dq & self.FLUX_ESTIMATED)
log.info(f"Input TSO integration {i} had {n_replaced} pixels replaced.")
self.input.data[i] = arrays.data
self.input.dq[i] = arrays.dq
self.input.err[i] = arrays.err
for key in var_dict.keys():
if self.input[key] is not None:
self.input[key][i] = getattr(arrays, key)
else:
# This should never happen, as these should be caught in the step code.
log.critical(
"Pixel replacement code did not filter this input correctly - skipping step."
)
return
[docs]
def fit_profile(self, arrays):
"""
Replace pixels with the profile fit method.
Fit a profile to adjacent columns, scale profile to
column with missing pixel(s), and find flux estimate
from scaled profile.
Error and variance values for the replaced pixels
are similarly estimated, using the scales from the
profile fit to the data.
Parameters
----------
arrays : `PixelReplaceArrays`
Pixel arrays and dispersion direction for the 2D spectrum to process.
Arrays are modified in place.
Returns
-------
arrays : `PixelReplaceArrays`
The input with bad pixels now flagged with FLUX_ESTIMATED
and holding a flux value estimated from the spatial profile.
"""
# np.nanmedian() entry full of NaN values would produce a numpy
# warning (despite well-defined behavior - return a NaN)
# so we suppress that here.
warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
dispaxis = arrays.dispersion_direction
# Make a copy of the input DQ, before replacement
input_dq = arrays.dq.copy()
# Truncate array to region where good pixels exist
good_pixels = np.where(~input_dq & self.DO_NOT_USE)
if np.any(0 in np.shape(good_pixels)):
log.warning(
"No good pixels in at least one dimension of "
"data array - skipping pixel replacement."
)
return arrays
x_range = [np.min(good_pixels[0]), np.max(good_pixels[0]) + 1]
y_range = [np.min(good_pixels[1]), np.max(good_pixels[1]) + 1]
valid_shape = [x_range, y_range]
profile_cut = valid_shape[dispaxis - 1]
# COMMENTS NOTE:
# In comments and parameter naming, I will try to be consistent in using
# "profile" to describe vectors in the spatial, i.e. cross-dispersion direction,
# and "slice" to describe vectors in the spectral, i.e. dispersion direction.
# Create set of slice indices which we can later use for profile creation
valid_profiles = set(range(*valid_shape[2 - dispaxis]))
profiles_to_replace = set()
# Loop over axis of data array corresponding to cross-
# dispersion direction by indexing data shape with
# strange dispaxis argument. Keep indices in full-frame numbering scheme,
# but only iterate through slices with valid data.
for ind in range(*valid_shape[2 - dispaxis]):
# Exclude regions with no data for dq slice.
dq_slice = input_dq[self.custom_slice(dispaxis, ind)][profile_cut[0] : profile_cut[1]]
# Exclude regions with NON_SCIENCE flag
dq_slice = np.where(dq_slice & self.NON_SCIENCE, self.NON_SCIENCE, dq_slice)
# Find bad pixels in region containing valid data.
n_bad = np.count_nonzero(dq_slice & self.DO_NOT_USE)
n_nonscience = np.count_nonzero(dq_slice & self.NON_SCIENCE)
if n_bad + n_nonscience == len(dq_slice):
log.debug(f"Slice {ind} contains no good pixels. Skipping replacement.")
valid_profiles.discard(ind)
elif n_bad == 0:
log.debug(f"Slice {ind} contains no bad pixels.")
else:
log.debug(f"Slice {ind} contains {n_bad} bad pixels.")
profiles_to_replace.add(ind)
log.debug(f"Number of profiles with at least one bad pixel: {len(profiles_to_replace)}")
for ind in profiles_to_replace:
# Use sets for convenient finding of neighboring slices to use in profile creation
adjacent_inds = set(
range(ind - self.pars["n_adjacent_cols"], ind + self.pars["n_adjacent_cols"] + 1)
)
adjacent_inds.discard(ind)
valid_adjacent_inds = list(adjacent_inds.intersection(valid_profiles))
# Cut out valid neighboring profiles
adjacent_condition = self.custom_slice(dispaxis, valid_adjacent_inds)
profile_data = arrays.data[adjacent_condition]
profile_err = arrays.err[adjacent_condition]
if profile_data.size == 0:
log.info(
f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
f"has no valid adjacent values - skipping."
)
continue
# Mask out bad pixels
invalid_condition = (input_dq[adjacent_condition] & self.DO_NOT_USE).astype(bool)
profile_data[invalid_condition] = np.nan
profile_err[invalid_condition] = np.nan
# Add additional cut to pull only from region with valid data
# for convenience (may not be necessary)
region_condition = self.custom_slice(3 - dispaxis, range(*profile_cut))
profile_data = profile_data[region_condition]
profile_snr = np.abs(profile_data / profile_err[region_condition])
# Normalize profile data
# TODO: check on signs here - absolute max sometimes picks up
# large negative outliers
profile_norm_scale = np.nanmax(np.abs(profile_data), axis=(dispaxis - 1), keepdims=True)
# If profile data has SNR < 5 everywhere just use unity scaling
# (so we don't normalize to noise)
if np.nanmax(profile_snr) < 5:
profile_norm_scale[:] = 1.0
normalized = profile_data / profile_norm_scale
# Get corresponding error and variance data and scale and mask to match
# Handle the variance arrays as errors, so the scales match.
err_names = ["err", "var_poisson", "var_rnoise", "var_flat"]
norm_errors = {}
for err_name in err_names:
if err_name.startswith("var"):
if (err_arr := getattr(arrays, err_name)) is None:
continue
err = np.sqrt(err_arr)
else:
err = getattr(arrays, err_name)
norm_err = err[adjacent_condition]
norm_err[invalid_condition] = np.nan
norm_errors[err_name] = norm_err[region_condition] / profile_norm_scale
# Pull median for each pixel across profile.
# Profile entry full of NaN values would produce a numpy
# warning (despite well-defined behavior - return a NaN)
# so we suppress that above.
median_profile = np.nanmedian(normalized, axis=(2 - dispaxis))
# Do the same for the errors
for err_name in norm_errors:
norm_errors[err_name] = np.nanmedian(norm_errors[err_name], axis=(2 - dispaxis))
# Clean current profile of values flagged as bad
current_condition = self.custom_slice(dispaxis, ind)
current_profile = arrays.data[current_condition]
cleaned_current = np.where(
input_dq[current_condition] & self.DO_NOT_USE, np.nan, current_profile
)[range(*profile_cut)]
replace_mask = np.where(~np.isnan(cleaned_current))[0]
if len(replace_mask) == 0:
log.info(
f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} "
f"has no valid values - skipping."
)
continue
min_median = median_profile[replace_mask]
min_current = cleaned_current[replace_mask]
norm_current = min_current / np.max(min_current)
# Scale median profile to current profile with bad pixel - minimize mse?
# Only do this scaling if we didn't default to all-unity scaling above,
# and require input values below 1e20 so that we don't overflow the
# minimization routine with extremely bad noise.
if (
(np.nanmedian(profile_norm_scale) != 1.0)
& (np.nanmax(np.abs(min_median)) < 1e20)
& (np.nanmax(np.abs(norm_current)) < 1e20)
):
# TODO: check on signs here - absolute max sometimes picks up
# large negative outliers
norm_scale = minimize(
self.profile_mse,
x0=np.abs(np.nanmax(norm_current)),
args=(np.abs(min_median), np.abs(norm_current)),
method="Nelder-Mead",
).x
scale = np.max(min_current)
else:
norm_scale = 1.0
scale = 1.0
# Replace pixels that are do-not-use but not non-science
current_dq = input_dq[current_condition][range(*profile_cut)]
replace_condition = (current_dq & self.DO_NOT_USE ^ current_dq & self.NON_SCIENCE) == 1
replaced_current = np.where(
replace_condition, median_profile * norm_scale * scale, cleaned_current
)
# Change the dq bits where old flag was DO_NOT_USE and new value is not nan
replaced_dq = np.where(
replace_condition & ~(np.isnan(replaced_current)),
current_dq ^ self.DO_NOT_USE ^ self.FLUX_ESTIMATED,
current_dq,
)
# Update data and DQ in the output model
arrays.data[current_condition][range(*profile_cut)] = replaced_current
arrays.dq[current_condition][range(*profile_cut)] = replaced_dq
# Also update the errors and variances
current_err = arrays.err[current_condition][range(*profile_cut)]
replaced_err = np.where(
replace_condition, norm_errors["err"] * norm_scale * scale, current_err
)
arrays.err[current_condition][range(*profile_cut)] = replaced_err
# Some values in NIRSpec variances may overflow in the squares - ignore the warning.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "overflow encountered", RuntimeWarning)
for var in ["var_poisson", "var_rnoise", "var_flat"]:
if (var_arr := getattr(arrays, var)) is not None:
current_var = var_arr[current_condition][range(*profile_cut)]
replaced_var = np.where(
replace_condition,
(norm_errors[var] * norm_scale * scale) ** 2,
current_var,
)
var_arr[current_condition][range(*profile_cut)] = replaced_var
setattr(arrays, var, var_arr)
return arrays
@staticmethod
def _interp_neighbors(arr, yindx, xindx):
"""
Interpolate using neighboring pixels in both horizontal and vertical directions.
Parameters
----------
arr : ndarray
2-D input array.
yindx, xindx : ndarray
1-D arrays, each length N, of row/column indices of the bad pixels.
Returns
-------
ndarray
Interpolations with shape of ``(2, N)`` in the horizontal (0th index)
and vertical (1st index) directions.
"""
horiz = (arr[yindx, xindx - 1] + arr[yindx, xindx + 1]) / 2.0
vert = (arr[yindx - 1, xindx] + arr[yindx + 1, xindx]) / 2.0
return np.array([horiz, vert])
[docs]
def mingrad(self, arrays):
"""
Replace pixels with the minimum gradient replacement method.
Test the gradient along the spatial and spectral axes using
immediately adjacent pixels. Pick whichever dimension has the minimum
absolute gradient and replace the missing pixel with the average
of the two adjacent pixels along that dimension.
This aims to make the process extremely local; near point sources it should do
the replacement along the spectral axis avoiding sampling issues, while near bright
extended emission line the replacement should be along the spatial axis. May still
be suboptimal near bright emission lines from unresolved point sources.
Does not attempt any replacement if a NaN value is bordered by another NaN value
along a given axis.
Parameters
----------
arrays : `PixelReplaceArrays`
Pixel arrays and dispersion direction for the 2D spectrum to process.
Arrays are modified in-place.
Returns
-------
arrays : `PixelReplaceArrays`
The input with flagged bad pixels now flagged with FLUX_ESTIMATED
and holding a flux value estimated from adjacent pixels.
"""
# np.nanmedian() entry full of NaN values would produce a numpy
# warning (despite well-defined behavior - return a NaN)
# so we suppress that here.
warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
log.info("Using minimum gradient method.")
in_var_dict = {
"var_poisson": None,
"var_rnoise": None,
"var_flat": None,
}
interp_rootvar_dict = {
"var_poisson": None,
"var_rnoise": None,
"var_flat": None,
}
for key in in_var_dict.keys():
if getattr(arrays, key) is not None:
in_var_dict[key] = getattr(arrays, key)
# Make an array of x/y values on the detector
(ysize, xsize) = arrays.data.shape
basex, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
pad = 1 # Padding around edge of array to ensure we don't look for neighbors outside array
# Find NaN-valued pixels
indx = np.where(
(~np.isfinite(arrays.data))
& (basey > pad)
& (basey < ysize - pad)
& (basex > pad)
& (basex < xsize - pad)
)
# X and Y indices
yindx, xindx = indx[0], indx[1]
# Absolute gradient along each axis from indata, shape (2, N), used to choose direction
diffs = np.array(
[
np.abs(arrays.data[yindx, xindx - 1] - arrays.data[yindx, xindx + 1]),
np.abs(arrays.data[yindx - 1, xindx] - arrays.data[yindx + 1, xindx]),
]
)
# Interpolated values for each quantity in both directions, shape (2, N)
interp_data = self._interp_neighbors(arrays.data, yindx, xindx)
interp_err = self._interp_neighbors(arrays.err, yindx, xindx)
# Propagate variance components as errors to get the scales right
for key in in_var_dict.keys():
if in_var_dict[key] is not None:
interp_rootvar_dict[key] = self._interp_neighbors(
np.sqrt(in_var_dict[key]), yindx, xindx
)
# Replace NaN diffs with inf so argmin naturally prefers the valid direction.
# Mask is True where at least one valid direction, False elsewhere,
# such that pixels where both diffs are inf have no usable neighbor pair and are skipped.
diffs_with_infs = np.where(np.isnan(diffs), np.inf, diffs) # (2, N)
mask = ~np.all(np.isinf(diffs_with_infs), axis=0) # (N,)
# Per-pixel direction index: 0 = horizontal, 1 = vertical
indmin = np.argmin(diffs_with_infs, axis=0) # (N,)
col_idx = np.arange(len(yindx))
# Select the minimium-gradient interpolated values and update model with them
indmin = indmin[mask]
col_idx = col_idx[mask]
arrays.data[yindx[mask], xindx[mask]] = interp_data[indmin, col_idx]
arrays.err[yindx[mask], xindx[mask]] = interp_err[indmin, col_idx]
# Square the interpolated errors back to variances for insertion
for key in interp_rootvar_dict.keys():
if interp_rootvar_dict[key] is not None:
in_var_dict[key][yindx[mask], xindx[mask]] = (
interp_rootvar_dict[key][indmin, col_idx] ** 2
)
setattr(arrays, key, in_var_dict[key])
# Update DQ flags for pixels that were replaced.
orig_dq = arrays.dq[yindx, xindx] # (N,)
remove_dnu = (
mask
& (orig_dq & self.DO_NOT_USE).astype(bool)
& ~(orig_dq & self.NON_SCIENCE).astype(bool)
)
arrays.dq[yindx[remove_dnu], xindx[remove_dnu]] -= self.DO_NOT_USE
arrays.dq[yindx[mask], xindx[mask]] |= self.FLUX_ESTIMATED
return arrays
[docs]
def custom_slice(self, dispaxis, index):
"""
Construct slice for ease of use with varying dispersion axis.
Parameters
----------
dispaxis : int
Using module-defined:
* 1 = HORIZONTAL
* 2 = VERTICAL
index : int or list
Index or indices of cross-dispersion
vectors to slice
Returns
-------
tuple
Slice constructed using numpy
"""
if dispaxis == self.HORIZONTAL:
return np.s_[:, index]
elif dispaxis == self.VERTICAL:
return np.s_[index, :]
else:
raise IndexError("Custom slice requires valid dispersion axis specification!")
[docs]
def profile_mse(self, scale, median, current):
"""
Calculate mean-squared error of fitted profile.
Parameters
----------
scale : float
Initial estimate of scale factor to bring
normalized median profile up to current profile
median : ndarray
Median profile constructed from neighboring profile slices
current : ndarray
Current profile with bad pixels to be replaced
Returns
-------
float
Mean-squared error for minimization purposes
"""
return np.nansum((current - (median * scale)) ** 2.0) / (
len(median) - np.count_nonzero(np.isnan(current))
)