Source code for hsr4hci.pca

"""
Methods for performing PCA-based PSF subtraction.
"""

# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

from copy import deepcopy
from typing import List, Optional, Tuple, Union, overload
from typing_extensions import Literal

from warnings import warn

from sklearn.decomposition import PCA

import numpy as np

from hsr4hci.derotating import derotate_combine
from hsr4hci.utils import check_consistent_size


# -----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------

@overload
def get_pca_signal_estimates(
    stack: np.ndarray,
    parang: np.ndarray,
    n_components: Union[int, List[int], Tuple[int], np.ndarray],
    return_components: Literal[True],
    roi_mask: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    ...  # pragma: no cover


@overload
def get_pca_signal_estimates(
    stack: np.ndarray,
    parang: np.ndarray,
    n_components: Union[int, List[int], Tuple[int], np.ndarray],
    return_components: Literal[False],
    roi_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    ...  # pragma: no cover


[docs]def get_pca_signal_estimates( stack: np.ndarray, parang: np.ndarray, n_components: Union[int, List[int], Tuple[int], np.ndarray], return_components: bool = True, roi_mask: Optional[np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Get the signal estimate (i.e., the derotated and combined stack that has been denoised) using PCA-based PSF subtraction for different numbers of principal components. .. note:: This function essentially provides a rather minimalistic implementation of PynPoint's :class:`pynpoint.processing.psfsubtraction.PcaPsfSubtractionModule`. Args: stack: A 3D numpy array of shape `(n_frames, width, height)` containing the stack for which to estimate the systematic noise using PCA. parang: A numpy array of shape `(n_frames,)` which contains the respective parallactic angle for each frame (necessary for derotating the stack). n_components: An iterable of integers, containing the values for the numbers of principal components for which to run PCA. return_components: Whether to return the principal components. roi_mask: A 2D binary mask of shape `(x_size, y_size)` that can be used to select the region of interest. If a ROI mask is given, only the pixels inside the ROI will be used to find the PCA basis. Returns: Depending on the value of ``return_components``, either * A 3D numpy array of shape `(N, width, height)` (where N is the number of elements in ``pca_numbers``), which contains the signal estimates for different numbers of principal components (ordered from lowest to highest number of PCs), or * A 2-tuple containing * A 3D numpy array with the signal estimates (see above), and * A 3D numpy array with the corresponding principal components. """ # ------------------------------------------------------------------------- # Preparations # ------------------------------------------------------------------------- # Check that stack and parallactic angles have matching sizes check_consistent_size(stack, parang, axis=0) # Define shortcuts n_frames, x_size, y_size = stack.shape frame_size = (x_size, y_size) # Turn a single integer into a list of integers if isinstance(n_components, int): n_components = [n_components] # Check condition_1 = isinstance(n_components, (list, tuple, np.ndarray)) condition_2 = all(isinstance(_, int) for _ in n_components) if not condition_1 or not condition_2: raise ValueError('n_components must be a (sequence of) integer(s)!') # Convert the component numbers into a sorted list n_components = sorted(list(n_components), reverse=True) # Find the maximum number of PCA components to use: This number cannot be # higher than the number of frames in the stack! max_n_components = max(n_components) if max_n_components > len(stack): warn( UserWarning( '(Maximum of) n_components cannot be larger than n_frames! ' 'max(n_components) was assumed as n_frames.' ) ) max_n_components = len(stack) # ------------------------------------------------------------------------- # Compute the PCA for the maximum number of components that we need # ------------------------------------------------------------------------- # Reshape stack from 3D to 2D (each frame is turned into a single 1D # vector). If a ROI mask is given, only the pixels inside the ROI are # used; otherwise, all pixels are used. if roi_mask is not None: reshaped_stack = stack[:, roi_mask] else: reshaped_stack = stack[:, np.full(frame_size, True)] # Instantiate new PCA instance with maximum number of principal # components, and fit it to the reshaped stack original_pca = PCA(n_components=max_n_components) original_pca.fit(reshaped_stack) truncated_pca = deepcopy(original_pca) # ------------------------------------------------------------------------- # Loop over the requested component numbers and compute signal estimates # ------------------------------------------------------------------------- # Initialize an empty array for the signal estimates signal_estimates = np.full((len(n_components), x_size, y_size), np.nan) # Loop over different numbers of principal components and compute the # signal estimates for that number of PCs for i, n in enumerate(n_components): # Only keep the first n principal components truncated_pca.components_ = truncated_pca.components_[:n] # Apply the dimensionality-reducing transformation transformed_stack = truncated_pca.transform(reshaped_stack) # Use inverse transform to map the dimensionality-reduced frame vectors # back into the original space so that we can interpret them as frames noise_estimate_ = truncated_pca.inverse_transform(transformed_stack) if roi_mask is not None: noise_estimate = np.full(stack.shape, np.nan) noise_estimate[:, roi_mask] = noise_estimate_ else: noise_estimate = noise_estimate_.reshape(stack.shape) # Subtract the noise estimate to compute the residual stack residual_stack = np.nan_to_num(stack - noise_estimate) # Derotate and combine the residuals to compute the signal estimate signal_estimate = derotate_combine(stack=residual_stack, parang=parang) # Restore ROI mask (if applicable) if roi_mask is not None: signal_estimate[~roi_mask] = np.nan # Store the signal estimate for the current number of PCs signal_estimates[len(n_components) - i - 1] = signal_estimate # ------------------------------------------------------------------------- # Return the signal estimates and, optionally, also the PCs # ------------------------------------------------------------------------- # Either return the signal estimates directly... if not return_components: return signal_estimates # ...or prepare the PCs first, and return them with the signal estimates else: # The reshaping of principal components into 2D frames depends on # whether we have used an ROI mask if roi_mask is not None: components = np.full((len(n_components), x_size, y_size), np.nan) for i in range(len(n_components)): components[i, roi_mask] = original_pca.components_[i] else: components = deepcopy(original_pca.components_) components = components.reshape((-1, x_size, y_size)) return signal_estimates, components