Source code for hsr4hci.plotting

"""
Methods for plotting.
"""

# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

from copy import copy
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union

import colorsys

from astropy.modeling import models, fitting
from matplotlib.axes import Axes
from matplotlib.cm import get_cmap as original_get_cmap
from matplotlib.colorbar import Colorbar
from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from photutils import CircularAperture

import matplotlib.colors as mc
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np

from hsr4hci.coordinates import get_center
from hsr4hci.masking import mask_frame_around_position


# -----------------------------------------------------------------------------
# TYPE DEFINITIONS
# -----------------------------------------------------------------------------

# A valid matplotlib color can be one of the following three options:
# - A string specifying either the name of the color, or defining the color
#   as a HEX string. Examples: "red", "C3", "#FF0000".
# - An RGB tuple specifying the color. Example: (1, 0, 0) for red.
# - An RGBA tuple, specifying the color and the alpha channel (i.e., the
#   transparency). Example: (1, 0, 0, 0.5) for a semitransparent red.
MatplotlibColor = Union[
    str, Tuple[float, float, float], Tuple[float, float, float, float]
]


# -----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------

[docs]def get_cmap( cmap_name: str = 'RdBu_r', bad_color: str = '#212121', ) -> Union[Colormap, LinearSegmentedColormap, ListedColormap]: """ Convenience wrapper around :func:`matplotlib.cm.get_cmap` which allows to also set the ``bad_color`` (i.e., the color for `NaN` values). Args: cmap_name: The name of a matplotlib color map (e.g., `"RdBu_r"` or `"viridis"`). bad_color: A string specifying a color in HTML format: (e.g., `"#FF0000"`) which will be used as the ``bad_color`` of the color map; that is, the color used, for example, for `NaN` values in :func:`matplotlib.pyplot.imshow` plots. Returns: A matplotlib colormap with the desired ``bad_color``. """ # Get desired color map and set the desired bad_color cmap = copy(original_get_cmap(cmap_name)) cmap.set_bad(color=bad_color) return cmap
[docs]def get_transparent_cmap(color: MatplotlibColor = 'red') -> ListedColormap: """ Return a colormap that goes from transparent to the target color. Color maps of this type can be useful, for example, when plotting or overlaying masks, where only the selected pixels should receive a color, while everything else stays Args: color: A valid matplotlib color. Returns: A ``ListedColormap``, which gradually goes from transparent to the given target color. """ return ListedColormap([(0, 0, 0, 0), color])
[docs]def add_colorbar_to_ax( img: AxesImage, fig: Figure, ax: Axes, where: str = 'right', ) -> Colorbar: """ Add a "nice" colorbar to an imshow plot. Original source: https://stackoverflow.com/a/18195921/4100721 Args: img: The return of a :func:`matplotlib.pyplot.imshow` command. fig: The figure that the plot is part of. ax: The ax which contains the plot. where: Where to place the colorbar (`"left"`, `"right"`, `"top"` or `"bottom"`). Returns: The colorbar that was added to the axis. """ if where in ('left', 'right'): orientation = 'vertical' elif where in ('top', 'bottom'): orientation = 'horizontal' else: raise ValueError( f'Illegal value for `where`: "{where}". Must be one ' 'of ["left", "right", "top", "bottom"].' ) divider = make_axes_locatable(ax) cax = divider.append_axes(where, size='5%', pad=0.05) cbar = fig.colorbar( img, cax=cax, orientation=orientation, ticklocation=where ) return cbar
[docs]def adjust_luminosity( color: MatplotlibColor, amount: float = 1.4 ) -> Tuple[float, float, float]: """ Adjust the luminosity of the given ``color`` by the ``amount``. Original source: https://stackoverflow.com/a/49601444/4100721 Args: color: The input color. Can either be a hex string (e.g., `"#FF0000"`), matplotlib color string (e.g., `"C1"` or `"green"`), or an RGB tuple in float format (e.g., `(1.0, 0.0, 0.0)`). amount: The amount by how much the input color should be lightened. For ``amount`` > 1, the color gets brighter; for ``amount`` < 1, the color is darkened. By default, colors are lightened by 40%. Returns: An RGB tuple describing the luminosity-adjusted input color. """ # In case `color` is a proper color name, we can try to resolve it into # an RGB tuple using the lookup table (of HEX strings) in mc.cnames. if isinstance(color, str) and (color in mc.cnames.keys()): rgb: Tuple[float, float, float] = mc.to_rgb(mc.cnames[color]) # Otherwise, we try to convert the color to RGB; this will raise a value # error for invalid color formats. else: rgb = mc.to_rgb(color) # Convert color from RBG to HLS representation hue, luminosity, saturation = colorsys.rgb_to_hls(*rgb) # Multiply `1 - luminosity` by given `amount` and convert back to RGB luminosity = max(0.0, min(1.0, amount * luminosity)) rgb = colorsys.hls_to_rgb(hue, luminosity, saturation) return rgb
[docs]def disable_ticks(ax: Axes) -> None: """ Disable the ticks and labels on the given matplotlib ``ax``. This is similar to calling ``ax.axis('off')``, except that the frame around the plot is preserved. Args: ax: A matplotlib axis. """ ax.tick_params( axis='both', which='both', top=False, bottom=False, left=False, right=False, labelbottom=False, labelleft=False, )
[docs]def zerocenter_imshow(ax: Axes) -> None: """ Make sure that the `(vmin, vmax)` range of the ``imshow()`` plot in the given ``ax`` object is symmetric around zero. Args: ax: The ax which contains the plot. """ # Get plot and current limits img = ax.get_images()[0] vmin, vmax = img.get_clim() # Compute and set new limits limit = max(np.abs(vmin), np.abs(vmax)) img.set_clim((-limit, limit))
[docs]def zerocenter_plot(ax: Axes, which: str) -> None: """ Make sure that the `xlim` or `ylim` range of the plot object in the given ``ax`` object is symmetric around zero. Args: ax: The ax which contains the plot. which: Which axis to center around zero (`"x"` or `"y"`). """ if which == 'x': vmin, vmax = ax.get_xlim() limit = max(np.abs(vmin), np.abs(vmax)) ax.set_xlim(xmin=-limit, xmax=limit) elif which == 'y': vmin, vmax = ax.get_ylim() limit = max(np.abs(vmin), np.abs(vmax)) ax.set_ylim(ymin=-limit, ymax=limit) else: raise ValueError('Parameter which must be "x" or "y"!')
[docs]def set_fontsize(ax: Axes, fontsize: int) -> None: """ Set the ``fontsize`` for all labels (title, x- and y-label, and tick labels) of a target ``ax`` at once. Args: ax: The ax which contains the plot. fontsize: The target font size for the labels. """ for item in ( [ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels() ): item.set_fontsize(fontsize)
# ----------------------------------------------------------------------------- # AUXILIARY FUNCTION DEFINITIONS AND PLOT_FRAME() # ----------------------------------------------------------------------------- def _determine_limit( frame: np.ndarray, positions: Optional[Sequence[Tuple[float, float]]], ) -> float: """ Auxiliary function to determine the plot limits for plot_frame(). """ # If no positions are given, simply use the 99.9th percentile of the # entire frame as the "global" limit if (positions is None) or (not positions): return float(np.nanpercentile(np.abs(frame), 99.9)) # Otherwise, loop over the positions, fit the frame at each position with # a 2D Gaussian, and set the limit to the maximum amplitude we find. # Define a grid for the fit x, y = np.meshgrid(np.arange(frame.shape[0]), np.arange(frame.shape[1])) # Keep track of the maximum amplitude (= the limit we will return). This # limit should always be positive! limit = float(np.nanmin(np.abs(frame))) # Loop over all given positions for position in positions: # Set up the model (and keep the mean = position fixed) model = models.Gaussian2D(x_mean=position[0], y_mean=position[1]) model.x_mean.fixed = True model.y_mean.fixed = True # Mask the frame (set everything to zero that is too far from position) masked_frame = mask_frame_around_position( frame=np.nan_to_num(frame), position=position, radius=5 ) # Fit the frame and update the limit fit_p = fitting.LevMarLSQFitter() model = fit_p(model=model, x=x, y=y, z=masked_frame) limit = max(limit, model.amplitude.value) return limit def _add_apertures_and_labels( ax: Axes, positions: Sequence[Tuple[float, float]], labels: Sequence[Union[str, float]], label_positions: Optional[Sequence[str]], aperture_radius: float, draw_color: MatplotlibColor, ) -> None: """ Auxiliary function for `plot_frame()` to add apertures and labels to mark planet positions and indicate the SNR / FPF / ... """ # Define default options for the label label_kwargs = dict( color='white', fontsize=6, bbox=dict( facecolor=draw_color, edgecolor='none', boxstyle='square,pad=0.2', ), ) # Draw apertures at positions (if positions are given) if positions: aperture = CircularAperture(positions=positions, r=aperture_radius) aperture.plot(axes=ax, **dict(lw=1, color=draw_color)) # If no label positions are given, assume all labels go to the right # of the position that they are annotation if label_positions is None: label_positions = len(labels) * ['right'] # Add labels for positions (if labels are given) if labels and positions and label_positions: for position, label, label_position in zip( positions, labels, label_positions ): # Determine positions for annotate() and the alignment of the # label based on `label_position` if label_position == 'right': xy = (position[0] + aperture_radius, position[1]) xytext = (8, 0) ha = 'left' va = 'center' elif label_position == 'left': xy = (position[0] - aperture_radius, position[1]) xytext = (-8, 0) ha = 'right' va = 'center' elif label_position == 'top': xy = (position[0], position[1] + aperture_radius) xytext = (0, 8) ha = 'center' va = 'bottom' elif label_position == 'bottom': xy = (position[0], position[1] - aperture_radius) xytext = (0, -8) ha = 'center' va = 'top' else: raise ValueError('Illegal value for label_position!') # Annotate the aperture with a label ax.annotate( text=label, xy=xy, xytext=xytext, textcoords='offset pixels', arrowprops=dict( arrowstyle='-', shrinkA=0, shrinkB=0, lw=1, color=draw_color, ), ha=ha, va=va, **label_kwargs, ) def _add_scalebar( ax: Axes, frame_size: Tuple[int, int], pixscale: float, color: MatplotlibColor = 'white', loc: str = 'upper right', ) -> float: """ Auxiliary function for `plot_frame()` to add a scale bar. """ # Compute size of the scale bar, and define its label accordingly scalebar_size = 1 / pixscale scalebar_label_value = 1.0 while scalebar_size > 0.3 * frame_size[0]: scalebar_size /= 2 scalebar_label_value /= 2 # Create the scale bar and add it to the frame scalebar = AnchoredSizeBar( transform=ax.transData, size=scalebar_size, label=f'{scalebar_label_value}"', loc=loc, pad=1, color=color, frameon=False, size_vertical=0, fontproperties=fm.FontProperties(size=6), ) ax.add_artist(scalebar) return scalebar_size def _add_ticks( ax: Axes, frame_size: Tuple[int, int], scalebar_size: float, color: MatplotlibColor = 'white', ) -> None: """ Auxiliary function for `plot_frame()` to add ticks to the frame. """ # Define shortcut for the center center = get_center(frame_size) # Define tick positions delta = scalebar_size / 2 xticks, yticks = [], [] for i in range(10): xticks += [center[0] - i * delta, center[0] + i * delta] yticks += [center[1] - i * delta, center[1] + i * delta] xticks = list(filter(lambda _: 0 < _ < frame_size[0], xticks)) yticks = list(filter(lambda _: 0 < _ < frame_size[1], yticks)) # Add ticks to the axis ax.set_xticks(xticks) ax.set_yticks(yticks) # Define which ticks to show ax.tick_params( axis='both', which='both', direction='in', color=color, length=1.25, top=True, bottom=True, left=True, right=True, labelleft=False, labelbottom=False, ) def _add_colorbar( img: AxesImage, limits: Tuple[float, float], fig: Figure, ax: Axes, use_logscale: bool, ) -> Colorbar: """ Auxiliary function for `plot_frame()` to add a colorbar. """ # Create a color bar at the bottom of the axis divider = make_axes_locatable(ax) cax = divider.append_axes('bottom', size='5%', pad=0.025) cbar = fig.colorbar(img, cax=cax, orientation='horizontal') # Unpack the limits vmin, vmax = limits # Set up the rest of the colorbar options if use_logscale: cbar.set_ticks([vmin / 2, vmin / 10, 0, vmax / 10, vmax / 2]) else: cbar.set_ticks([2 * vmin / 3, vmin / 3, 0, vmax / 3, 2 * vmax / 3]) cbar.ax.set_xticklabels(["{:.1f}".format(i) for i in cbar.get_ticks()]) cbar.ax.tick_params(labelsize=5) return cbar def _add_cardinal_directions( ax: Axes, color: MatplotlibColor = 'white', ) -> None: """ Auxiliary function for `plot_frame()` to add cardinal directions. """ # Define position (i.e., where do the arrows start) and length of arrows position = (0.95, 0.05) arrow_length = 0.075 # Define common parameters for annotate() params = dict( xycoords='axes fraction', textcoords='axes fraction', arrowprops=dict( arrowstyle='<-', lw=0.75, color=color, shrinkA=2.5, shrinkB=0, patchA=None, patchB=None, ), color=color, fontsize=6, bbox=dict(fc='none', ec='none', pad=0), ) # Plot an arrow for "North" and "East" ax.annotate( 'N', xy=position, xytext=(position[0], position[1] + arrow_length), ha='center', va='bottom', **params, ) ax.annotate( 'E', xy=position, xytext=(position[0] - arrow_length, position[1]), ha='right', va='center', **params, )
[docs]def plot_frame( frame: np.ndarray, positions: Sequence[Tuple[float, float]], labels: Sequence[Union[str, float]], pixscale: float, figsize: Tuple[float, float] = (4.3 / 2.54, 5.0 / 2.54), subplots_adjust: Optional[Dict[str, float]] = None, aperture_radius: float = 0, label_positions: Optional[Sequence[str]] = None, draw_color: MatplotlibColor = 'darkgreen', scalebar_color: MatplotlibColor = 'white', cmap: str = 'RdBu_r', limits: Optional[Tuple[float, float]] = None, use_logscale: bool = False, add_colorbar: bool = True, add_scalebar: bool = True, add_cardinal_directions: bool = True, scalebar_loc: str = 'upper right', file_path: Optional[Union[Path, str]] = None, ) -> Tuple[Figure, Axes, Optional[Colorbar]]: """ Plot a single frame (e.g., a signal estimate) with various options. This function was used to generate most of the result plots in the paper. Args: frame: A 2D numpy array of shape `(x_size, y_size)` containing the frame to be plotted (e.g., a signal estimate). positions: A list of positions (which may be empty). At each position, an aperture is drawn with the given radius. labels: A list of labels (which may be empty) that are placed next to the apertures drawn at the ``positions``. Can be used, for example, to add the SNR or FPF to the plot. pixscale: The pixel scale, in units of arcsecond / pixel. Only needed if ``add_scalebar`` is True. figsize: A two-tuple `(x_size, y_size)` containing the size of the figure in inches. subplots_adjust: Dictionary with parameters that will be passed to :func:`matplotlib.pyplot.subplots_adjust`. aperture_radius: The radius of the apertures to be drawn at the given ``positions``. If ``positions`` is empty, this value is never used. label_positions: A list of strings (either `"right`", `"left`", `"top"` or `"bottom"`) that indicates, for each label where this label should be placed relative to the position that it annotates. Default is `"right"` for all labels. draw_color: The color that is used for drawing the apertures and also labels. scalebar_color: The color that is used for the scale bar and the ticks. cmap: Name of the color map to be used for plotting. limits: A tuple `(vmin, vmax)` that is used for the plot limits. If None, the limits are estimated from the data. use_logscale: Whether to use a (symmetric) log scale. add_colorbar: Whether to add a colorbar at the bottom. add_scalebar: Whether to add a scale bar and a grid of ticks around the borders of the frame (to better understand the scale of the frame). add_cardinal_directions: Whether to add labeled arrows to indicate the cardinal directions (North and East). scalebar_loc: Location parameter for the scalebar. file_path: The path at which to save the resulting plot. The path should include the file name plus file ending. If None is given, the plot is not saved. Returns: A 3-tuple containing 1. the current matplotlib figure, 2. the current axis containing the plot of the frame, and 3. the colorbar object (or `None`, if no colorbar was added). """ # Define shortcuts frame_size = (frame.shape[0], frame.shape[1]) center = get_center(frame_size) # In case no explicit plot limit is specified, determine it from the data if limits is None: vmax = _determine_limit(frame=frame, positions=positions) vmin = -vmax else: vmin, vmax = limits # Set up the `norm`, which determines whether we use linear or log scale if use_logscale: norm = mc.SymLogNorm(linthresh=0.1 * vmax, vmin=vmin, vmax=vmax) else: norm = mc.PowerNorm(gamma=1, vmin=vmin, vmax=vmax) # Prepare grid for the pcolormesh() x, y = np.meshgrid(np.arange(frame.shape[0]), np.arange(frame.shape[1])) # Prepare parameters for the adjust_subplots() call if subplots_adjust is None: subplots_adjust = dict(left=0, top=1, right=1, bottom=0.075) # Set up a new figure and adjust margins fig, ax = plt.subplots(figsize=figsize) fig.subplots_adjust(**subplots_adjust) # Create the actual plot # Using pcolormesh() instead of imshow() avoids interpolation artifacts in # most PDF viewers (otherwise, the PDF version will often look blurry). img = ax.pcolormesh( x, y, frame, shading='nearest', cmap=get_cmap(cmap), snap=True, rasterized=True, norm=norm, ) ax.set_aspect('equal') # Place a "+"-marker at the center for the frame ax.plot(center[0], center[1], '+', ms=6, color='black') # Plot apertures and add labels if positions: _add_apertures_and_labels( ax, positions, labels, label_positions, aperture_radius, draw_color ) # If desired, add a scale bar and a grid of ticks if add_scalebar: scalebar_size = _add_scalebar( ax, frame_size, pixscale, scalebar_color, scalebar_loc ) _add_ticks(ax, frame_size, scalebar_size, scalebar_color) else: disable_ticks(ax) # If desired, add the cardinal directions if add_cardinal_directions: _add_cardinal_directions(ax, scalebar_color) # If desired, add a color bar if add_colorbar: cbar = _add_colorbar(img, (vmin, vmax), fig, ax, use_logscale) else: cbar = None # Save the results, if desired if file_path is not None: plt.savefig(file_path, pad_inches=0, dpi=600) return fig, ax, cbar