# Source code for hsr4hci.typehinting

"""
Classes (protocols) for additional custom type hinting.

.. tip::
    The classes in this module exist only for type hinting; the bodies
    of their methods are never actually executed.
"""

# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

# This requires Python 3.8+
from typing import Protocol

import numpy as np


# -----------------------------------------------------------------------------
# TYPE DEFINITIONS
# -----------------------------------------------------------------------------

class RegressorModel(Protocol):
    """
    Structural type hint for a generic regressor in the ``sklearn`` style.

    Any object that offers a ``fit()`` method (learning a mapping from
    predictors `X` to targets `y`) and a ``predict()`` method matches this
    protocol; because it is a ``Protocol``, no explicit inheritance from
    this class is required.
    """

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'RegressorModel':
        """
        Fit the model to the predictors ``X`` and the targets ``y``.

        The return value is annotated as a ``RegressorModel`` (in line
        with the sklearn habit of ``fit()`` returning the estimator).
        """
        ...  # pragma: no cover

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Compute and return the model's predictions for the predictors ``X``.
        """
        ...  # pragma: no cover
class BaseLinearModel(RegressorModel, Protocol):
    """
    Structural type hint for linear regressors from ``sklearn``.

    Besides the ``fit()`` / ``predict()`` interface inherited from
    ``RegressorModel``, a linear model exposes its coefficient vector as
    ``coef_`` and its intercept term as ``intercept_``.
    """

    # Coefficient vector of the linear model.
    coef_: np.ndarray

    # Intercept term of the linear model.
    # NOTE(review): sklearn documents ``intercept_`` as "float or array"
    # (e.g. an array for multi-target fits); confirm that ``float`` covers
    # every model passed where this protocol is expected.
    intercept_: float
class BaseLinearModelCV(BaseLinearModel, Protocol):
    """
    Structural type hint for cross-validated linear models from
    ``sklearn`` (for instance, ``RidgeCV``).

    On top of everything ``BaseLinearModel`` requires, such a model has an
    ``alpha_`` attribute that holds the value of the regularization
    parameter selected by the cross-validation.
    """

    # Regularization parameter picked by cross-validation.
    # NOTE(review): for most sklearn CV estimators ``alpha_`` is a float
    # (RidgeCV stores an array only with ``alpha_per_target=True``); the
    # ``np.ndarray`` annotation may be too narrow -- confirm against callers
    # before changing it.
    alpha_: np.ndarray