"""
Methods for merging partial result files (FITS / HDF).
"""
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------
from pathlib import Path
from typing import Dict, List, Sequence
from warnings import warn, catch_warnings, filterwarnings
import os
from tqdm.auto import tqdm
import numpy as np
from hsr4hci.fits import read_fits
from hsr4hci.hdf import load_dict_from_hdf
# -----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------
def get_list_of_fits_file_paths(fits_dir: Path, prefix: str) -> List[Path]:
    """
    Get a list of all FITS files in a given ``fits_dir`` whose file
    name begins with the given ``prefix``.

    As a sanity check, the number of files found is compared with the
    number of splits encoded in the name of the first file (naming
    convention: ``"<prefix>_<split>-<n_splits>.fits"``). A mismatch, an
    unparsable file name, or an empty result only triggers a warning;
    it never raises.

    Args:
        fits_dir: Path to directory in which to look for FITS files.
        prefix: Only consider FITS files whose names begin with this.
            For example: `"hypotheses"` or `"mean_mf"`.

    Returns:
        A sorted list of Paths to the matching FITS files in
        ``fits_dir``. The list is empty if no file matches.
    """

    # Collect the paths to all FITS files in the given directory whose
    # names start with the given prefix (e.g., "hypotheses" or "mean_mf")
    fits_file_paths = sorted(
        fits_dir / name
        for name in os.listdir(fits_dir)
        if name.endswith('.fits') and name.startswith(prefix)
    )

    # Without any matching file there is nothing to check (and no first
    # file name to parse), so we warn and return early
    if not fits_file_paths:
        warn(f'No FITS files with prefix "{prefix}" found in {fits_dir}!')
        return fits_file_paths

    # Perform a quick sanity check: Does the number of FITS files we found
    # match the number that we would expect based on the naming convention?
    # We split at the *last* dash so that prefixes which contain a dash
    # themselves do not break the parsing of "<n_splits>".
    first_name = fits_file_paths[0].name
    try:
        expected_number = int(first_name.rsplit('-', 1)[1].split('.')[0])
    except (IndexError, ValueError):
        # The check is purely advisory, so a file name that does not
        # follow the convention should not abort the whole lookup
        warn(
            f'Could not determine the expected number of FITS files '
            f'from the file name "{first_name}"!'
        )
        return fits_file_paths

    actual_number = len(fits_file_paths)
    if expected_number != actual_number:
        warn(
            f'Naming convention suggests there should be {expected_number} '
            f'FITS files, but {actual_number} were found!'
        )

    return fits_file_paths
def get_list_of_hdf_file_paths(
    hdf_dir: Path, prefix: str = 'residuals'
) -> List[Path]:
    """
    Get a list of all HDF files in a given ``hdf_dir`` whose file name
    begins with the given ``prefix``.

    As a sanity check, the number of files found is compared with the
    number of splits encoded in the name of the first file (naming
    convention: ``"<prefix>_<split>-<n_splits>.hdf"``). A mismatch, an
    unparsable file name, or an empty result only triggers a warning;
    it never raises.

    Args:
        hdf_dir: Path to directory in which to look for HDF files.
        prefix: Only consider HDF files whose names begin with this.
            Usually, we only need HDF files starting with `"residuals"`.

    Returns:
        A sorted list of Paths to the matching HDF files in
        ``hdf_dir``. The list is empty if no file matches.
    """

    # Collect the paths to all HDF files in the given directory whose
    # names start with the given prefix
    hdf_file_paths = sorted(
        hdf_dir / name
        for name in os.listdir(hdf_dir)
        if name.endswith('.hdf') and name.startswith(prefix)
    )

    # Without any matching file there is nothing to check (and no first
    # file name to parse), so we warn and return early
    if not hdf_file_paths:
        warn(f'No HDF files with prefix "{prefix}" found in {hdf_dir}!')
        return hdf_file_paths

    # Perform a quick sanity check: Does the number of HDF files we found
    # match the number that we would expect based on the naming convention?
    # We split at the *last* dash so that prefixes which contain a dash
    # themselves do not break the parsing of "<n_splits>".
    first_name = hdf_file_paths[0].name
    try:
        expected_number = int(first_name.rsplit('-', 1)[1].split('.')[0])
    except (IndexError, ValueError):
        # The check is purely advisory, so a file name that does not
        # follow the convention should not abort the whole lookup
        warn(
            f'Could not determine the expected number of HDF files '
            f'from the file name "{first_name}"!'
        )
        return hdf_file_paths

    actual_number = len(hdf_file_paths)
    if expected_number != actual_number:
        warn(
            f'Naming convention suggests there should be {expected_number} '
            f'HDF files, but {actual_number} were found!'
        )

    return hdf_file_paths
def merge_hdf_files(
    hdf_file_paths: Sequence[Path],
) -> Dict[str, np.ndarray]:
    """
    Merge a collection of (partial) HDF result files into one dict.

    This function is intended to merge the (partial) results files that
    are produced by :func:`hsr4hci.training.train_all_models`; see there
    for more details on the expected internal structure of the files.
    Every file is expected to provide the entries ``"stack_shape"``,
    ``"roi_mask"`` and ``"residuals"``.

    Args:
        hdf_file_paths: The paths to the HDF files to be merged. They
            are processed in sorted order.

    Returns:
        A dictionary that maps each key found in the ``"residuals"``
        groups to an array holding the merged residuals of all files.
    """

    # This dictionary accumulates the merged residuals for each key
    merged: Dict[str, np.ndarray] = {}

    # Process the HDF files one by one (in sorted order)
    for path in tqdm(sorted(hdf_file_paths), ncols=80):

        # Read the contents of the current file, as well as the stack
        # shape and the ROI mask that go along with its residuals
        contents = load_dict_from_hdf(file_path=path)
        stack_shape = tuple(contents['stack_shape'])
        roi_mask = np.asarray(contents['roi_mask'])

        # The `key` is either "default", or "0", ... "N" (i.e., the
        # different signal_times for which a model was trained); the
        # `partial` array holds the (partial) residuals for this key
        for key, partial in contents['residuals'].items():

            # The first time we see a key, start from an all-NaN array
            if key not in merged:
                merged[key] = np.full(stack_shape, np.nan, dtype=np.float32)

            # The residuals in the HDF files should *always* be either
            # 2D or 3D; everything else is an error
            n_dims = partial.ndim
            if n_dims not in (2, 3):  # pragma: no cover
                raise RuntimeError('ndim must be either 2 or 3!')

            # 2D residuals (return_format == "partial") are written to
            # the positions that are selected by the (partial) ROI mask
            if n_dims == 2:
                merged[key][:, roi_mask] = partial
                continue

            # 3D residuals (return_format == "full") are combined with
            # the result so far by taking the "NaN union" of the two
            with catch_warnings():
                filterwarnings('ignore', r'Mean of empty slice')
                merged[key] = np.nanmean([merged[key], partial], axis=0)

    return merged
def merge_fits_files(fits_file_paths: List[Path]) -> np.ndarray:
    """
    Merge the arrays from a list of FITS files into a single array.

    This function can merge the partial result files that are obtained
    in parallel with :func:`hsr4hci.hypotheses.get_all_hypotheses` and
    :func:`hsr4hci.match_fraction.get_all_match_fractions`.

    The files are combined one after the other: the array from each
    further file is stacked with the result so far along a new axis,
    and the :func:`numpy.nanmean` along that axis becomes the new
    result. This, of course, assumes that each pixel only takes on a
    non- `NaN` value in at most one of the FITS files.

    Args:
        fits_file_paths: List of FITS files to be merged.

    Returns:
        A numpy array containing the merged arrays from all FITS files.
    """

    def _nan_union(so_far: np.ndarray, update: np.ndarray) -> np.ndarray:
        # Pixels that are NaN in both inputs make numpy complain about
        # a "Mean of empty slice"; this is expected, so we silence it
        with catch_warnings():
            filterwarnings('ignore', r'Mean of empty slice')
            return np.nanmean([so_far, update], axis=0)

    # The array from the first file is the starting point of the result
    result = read_fits(fits_file_paths[0], return_header=False)

    # Fold the arrays from all remaining files into the result
    for path in fits_file_paths[1:]:
        result = _nan_union(result, read_fits(path, return_header=False))

    return result