Source code for hsr4hci.match_fraction

"""
Methods for computing match fractions.
"""

# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

from typing import Dict, Tuple

from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm

import numpy as np

from hsr4hci.coordinates import get_center, cartesian2polar
from hsr4hci.forward_modeling import add_fake_planet
from hsr4hci.general import find_closest, rotate_position
from hsr4hci.masking import get_positions_from_mask


# -----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------

def get_all_match_fractions(
    residuals: Dict[str, np.ndarray],
    roi_mask: np.ndarray,
    hypotheses: np.ndarray,
    parang: np.ndarray,
    psf_template: np.ndarray,
    frame_size: Tuple[int, int],
    n_roi_splits: int = 1,
    roi_split: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the match fraction maps for (a subset of) the ROI.

    This is essentially a convenience function which wraps the loop
    over the ROI and calls :func:`get_match_fraction_for_position()`
    for every spatial pixel.

    Args:
        residuals: A dictionary containing the full residuals as they
            are produced by :func:`hsr4hci.training.train_all_models`.
            Only keys that consist purely of digits are interpreted as
            signal times; all other keys are ignored here.
        roi_mask: A numpy array containing the mask of the region of
            interest. The positions to process are obtained from it
            via :func:`hsr4hci.masking.get_positions_from_mask`.
        hypotheses: A 2D numpy array containing the hypotheses map.
        parang: A 1D numpy array of shape `(n_frames, )` containing the
            parallactic angle for every frame.
        psf_template: A 2D numpy array containing the unsaturated PSF
            template.
        frame_size: A tuple `(x_size, y_size)` containing the spatial
            size of the input stack in pixels.
        n_roi_splits: Total number of splits for the ROI if we want to
            compute the match fraction map in parallel.
        roi_split: Index of the ROI split that we want to process here.

    Returns:
        A 3-tuple consisting of

        1. ``mean_mfs``: A 2D numpy array containing the match fraction
           map when using the mean to average.
        2. ``median_mfs``: A 2D numpy array containing the match
           fraction map when using the median to average.
        3. ``affected_pixels``: A 4D numpy array which, for each
           position `(x, y)`, contains the 2D binary mask of affected
           pixels (see :func:`get_match_fraction_for_position`).

        Entries for positions that were not processed by this split
        remain `NaN` in all three arrays.

    Raises:
        ValueError: If ``n_roi_splits`` is not positive, or if
            ``roi_split`` is not in `[0, n_roi_splits)`.
    """

    # Reject split parameters that would silently select a wrong (reversed,
    # overlapping or empty) subset of the ROI
    if n_roi_splits < 1:
        raise ValueError('n_roi_splits must be a positive integer!')
    if not 0 <= roi_split < n_roi_splits:
        raise ValueError('roi_split must be in [0, n_roi_splits)!')

    # Initialize array for the match fractions (mean and median)
    mean_mfs = np.full(frame_size, np.nan)
    median_mfs = np.full(frame_size, np.nan)

    # Define an array in which we keep track of the "affected pixels" (i.e.,
    # the planet traces for every hypothesis), mostly for debugging purposes
    affected_pixels = np.full(frame_size + frame_size, np.nan)

    # Define positions for which to run (= subset of the ROI)
    positions = get_positions_from_mask(roi_mask)[roi_split::n_roi_splits]

    # Get signal times based on the (purely numeric) keys of the residuals
    signal_times = np.array(
        sorted(int(key) for key in residuals if key.isdigit())
    )

    # Loop over (subset of) ROI and compute match fractions
    for position in tqdm(positions, ncols=80):

        # Make sure we index with a tuple: indexing with a list or an array
        # would trigger fancy indexing and overwrite entire rows
        index = tuple(position)

        mean_mf, median_mf, affected_mask = get_match_fraction_for_position(
            position=index,
            hypothesis=hypotheses[index[0], index[1]],
            residuals=residuals,
            parang=parang,
            psf_template=psf_template,
            signal_times=signal_times,
            frame_size=frame_size,
        )

        mean_mfs[index] = mean_mf
        median_mfs[index] = median_mf
        affected_pixels[index] = affected_mask

    return mean_mfs, median_mfs, affected_pixels


def get_match_fraction_for_position(
    position: Tuple[int, int],
    hypothesis: float,
    residuals: Dict[str, np.ndarray],
    parang: np.ndarray,
    psf_template: np.ndarray,
    signal_times: np.ndarray,
    frame_size: Tuple[int, int],
) -> Tuple[float, float, np.ndarray]:
    """
    Compute the match fraction for a single given position.

    Args:
        position: A tuple `(x, y)` specifying the position for which to
            compute the match fraction.
        hypothesis: The hypothesis (= temporal index) for the given
            ``position``. In general, this should be an integer, but
            the type here has to be a ``float`` because the value may
            also be `NaN` (in case there is no hypothesis).
        residuals: A dictionary containing the full residuals as they
            are produced by :func:`hsr4hci.training.train_all_models`.
            It is indexed with ``str(signal_time)``.
        parang: A 1D numpy array of shape `(n_frames, )` containing the
            parallactic angle for every frame.
        psf_template: A 2D numpy array containing the unsaturated PSF
            template.
        signal_times: A 1D numpy array of shape `(n_signal_times, )`
            containing the temporal grid.
        frame_size: A tuple `(x_size, y_size)` containing the spatial
            size of the input stack in pixels.

    Returns:
        A 3-tuple consisting of

        1. ``match_fraction__mean``: The match fraction for the given
           target ``position`` when using the mean to average.
        2. ``match_fraction__median``: The match fraction for the given
           target ``position`` when using the median to average.
        3. ``affected_mask``: A 2D numpy array containing a binary mask
           that indicates the pixels from which the match fraction was
           computed (i.e., the pixels that are affected by the planet
           according to the ``hypothesis``).

        Both match fractions are `NaN` if there is no hypothesis, if
        the expected signal is empty, or if no affected pixel has a
        usable residual.
    """

    # If we do not have a hypothesis for the current position, we can
    # directly return the match fraction as NaN (with an empty mask)
    if np.isnan(hypothesis):
        return np.nan, np.nan, np.full(frame_size, False)

    # Define shortcut for number of frames
    n_frames = len(parang)

    # Compute the expected final position based on the hypothesis that the
    # signal is at `position` at time `signal_time`
    final_position = rotate_position(
        position=position[::-1],  # position is in numpy coordinates
        center=get_center(frame_size),
        angle=float(parang[int(hypothesis)]),
    )

    # Compute the *full* expected signal stack under this hypothesis
    expected_stack = add_fake_planet(
        stack=np.zeros((n_frames,) + frame_size),
        parang=parang,
        psf_template=psf_template,
        polar_position=cartesian2polar(
            position=(final_position[0], final_position[1]),
            frame_size=frame_size,
        ),
        magnitude=0,
        extra_scaling=1,
        dit_stack=1,
        dit_psf_template=1,
        return_planet_positions=False,
        interpolation='bilinear',
    )

    # If the expected stack has no usable peak, normalizing it would mean
    # dividing by zero (or NaN / inf). No pixel can pass the threshold in
    # that case, so return the empty result without triggering warnings.
    peak_value = np.max(expected_stack)
    if not np.isfinite(peak_value) or peak_value == 0:
        return np.nan, np.nan, np.full(frame_size, False)

    # Normalize the stack (so that the thresholding below works reliably!)
    expected_stack = np.asarray(expected_stack / peak_value)

    # Find mask of all pixels that are affected by the planet trace, i.e.,
    # all pixels that at some point in time contain planet signal.
    # The threshold value of 0.5 is a bit of a magic number: it serves to pick
    # only those pixels really affected by the central peak of the signal, and
    # not the secondary maxima. The secondary maxima are often too low to be
    # picked up by the HSR, and including them would lower the match fraction.
    affected_mask = np.max(expected_stack, axis=0).astype(float) >= 0.5

    # For every pixel, find the time at which it is affected the most. This
    # is loop-invariant, so we compute it once for the entire frame.
    peak_times = np.argmax(expected_stack, axis=0)

    # Keep track of the matches (similarity scores) for affected positions
    matches = []

    # Convert signal_times to list (to avoid mypy issue with find_closest())
    signal_times_list = list(signal_times)

    # Make sure the comparison below also works if `position` was passed as
    # a list or an array instead of a tuple
    target = tuple(position)

    # Loop over all affected positions and check how well the residuals
    # match the expected signals
    for (x, y) in get_positions_from_mask(affected_mask):

        # Skip the hypothesis itself: We know that it is a match, because
        # otherwise it would not be our hypothesis
        if (x, y) == target:
            continue

        # Find the closest matching signal time for which we have trained a
        # model and therefore have a residual to compare with
        _, peak_time = find_closest(signal_times_list, int(peak_times[x, y]))

        # Define shortcuts for the time series that we compare
        a = expected_stack[:, x, y]
        b = np.asarray(residuals[str(peak_time)][:, x, y])

        # In case we do not have a (signal masking) residual for the
        # current affected position, we skip it
        if np.isnan(b).any():
            continue

        # Compute the cosine similarity between the expected signal and the
        # "best" residual as a measure for how well the current pixel (x, y)
        # matches our hypothesis for `position`.
        similarity = cosine_similarity(X=a.reshape(1, -1), Y=b.reshape(1, -1))

        # The result is an array of shape (1, 1); extract the element
        # explicitly because calling float() on an array with ndim > 0 is
        # deprecated in recent versions of numpy
        matches.append(float(similarity[0, 0]))

    # Compute mean and median match fraction for current position
    if matches:
        match_fraction__mean = float(np.nanmean(matches))
        match_fraction__median = float(np.nanmedian(matches))
    else:
        match_fraction__mean = np.nan
        match_fraction__median = np.nan

    return match_fraction__mean, match_fraction__median, affected_mask