"""
Abstract inference class and static utilities.
"""
__author__ = "Janek Sendrowski"
__contact__ = "sendrowski.janek@gmail.com"
__date__ = "2023-03-12"
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Literal, Tuple, Dict, Sequence
import jsonpickle
import numpy as np
import pandas as pd
from typing_extensions import Self
from .bootstrap import Bootstrap
from .parametrization import Parametrization, _from_string
from .utils import Serializable
logger = logging.getLogger("fastdfe")
class Inference:
    """
    Static utility methods for inference objects.

    All methods are static and operate on one or several objects that
    provide the :class:`AbstractInference` interface.
    """

    @staticmethod
    def plot_discretized(
            inferences: List['AbstractInference'],
            intervals: Sequence = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean',
            file: str = None,
            show: bool = True,
            title: str = 'discretized DFEs',
            labels: Sequence = None,
            ax: 'plt.Axes' = None,
            kwargs_legend: dict = dict(prop=dict(size=8)),
            **kwargs
    ) -> 'plt.Axes':
        """
        Visualize several discretized DFEs given by the list of inference objects.

        :param inferences: List of inference objects.
        :param intervals: Intervals over ``(-inf, inf)`` to use for discretization.
        :param confidence_intervals: Whether to plot confidence intervals.
        :param ci_level: Confidence level for confidence intervals.
        :param bootstrap_type: Type of bootstrap to use for confidence intervals.
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :param file: Path to file to save the plot to.
        :param show: Whether to show the plot.
        :param title: Title of the plot.
        :param labels: Labels for the DFEs.
        :param ax: Axes to plot on. Only for Python visualization backend.
        :param kwargs_legend: Keyword arguments passed to :meth:`plt.legend`. Only for Python visualization backend.
        :param kwargs: Additional keyword arguments. They are accepted but not used by this method.
        :return: Axes of the plot.
        """
        # imported lazily inside the function, as in the other plotting helpers
        from .visualization import Visualization

        # collect discretized values and errors, one entry per inference object
        values = []
        errors = []
        for inference in inferences:
            val, errs = inference.get_discretized(
                intervals=np.array(intervals),
                confidence_intervals=confidence_intervals,
                ci_level=ci_level,
                bootstrap_type=bootstrap_type,
                point_estimate=point_estimate
            )

            values.append(val)
            errors.append(errs)

        # plot DFEs
        return Visualization.plot_discretized(
            values=values,
            errors=errors,
            labels=labels,
            file=file,
            show=show,
            intervals=np.array(intervals),
            title=title,
            ax=ax,
            kwargs_legend=kwargs_legend
        )

    @staticmethod
    def plot_continuous(
            inferences: List['AbstractInference'],
            intervals: np.ndarray = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            file: str = None,
            show: bool = True,
            title: str = 'continuous DFEs',
            labels: Sequence = None,
            scale: Literal['lin', 'log', 'symlog'] = 'lin',
            scale_density: bool = False,
            ax: 'plt.Axes' = None,
            kwargs_legend: dict = dict(prop=dict(size=8)),
            **kwargs
    ) -> 'plt.Axes':
        """
        Visualize several DFEs given by the list of inference objects.

        By default, the PDF is plotted as is. Due to the logarithmic scale on
        the x-axis, we may get a wrong intuition on how the mass is distributed,
        however. To get a better intuition, we can optionally scale the density
        by the x-axis interval size using ``scale_density = True``. This has the
        disadvantage that the density now changes for x, so that even a constant
        density will look warped.

        :param inferences: List of inference objects.
        :param intervals: Intervals to use for discretization.
        :param confidence_intervals: Whether to plot confidence intervals.
        :param ci_level: Confidence level for confidence intervals.
        :param bootstrap_type: Type of bootstrap to use for confidence intervals.
        :param file: Path to file to save the plot to.
        :param show: Whether to show the plot.
        :param title: Title of the plot.
        :param labels: Labels for the DFEs.
        :param scale: y-scale of the plot.
        :param scale_density: Whether to scale the density by the x-axis interval size.
        :param ax: Axes to plot on. Only for Python visualization backend.
        :param kwargs_legend: Keyword arguments passed to :meth:`plt.legend`. Only for Python visualization backend.
        :param kwargs: Additional keyword arguments. They are not unpacked here but
            forwarded as a single ``kwargs`` entry through ``locals()``.
        :return: Axes of the plot.
        """
        from .visualization import Visualization

        # get data from inference objects
        # (no ``point_estimate`` is passed, so the callee's default applies)
        values = []
        errors = []
        for i, inf in enumerate(inferences):
            val, errs = inf.get_discretized(
                intervals=intervals,
                confidence_intervals=confidence_intervals,
                ci_level=ci_level,
                bootstrap_type=bootstrap_type
            )

            values.append(val)
            errors.append(errs)

        # plot DFEs
        # NOTE: ``locals()`` forwards *every* local name of this function
        # (arguments, ``values``, ``errors``, loop variables, ``Visualization``)
        # as keyword arguments. Renaming, adding or removing a local variable
        # therefore changes the call below.
        return Visualization.plot_continuous(
            bins=intervals,
            **locals()
        )

    @staticmethod
    def plot_inferred_parameters(
            inferences: List['AbstractInference'],
            labels: Sequence,
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean',
            file: str = None,
            show: bool = True,
            title: str = 'parameter estimates',
            scale: Literal['lin', 'log', 'symlog'] = 'log',
            ax: 'plt.Axes' = None,
            kwargs_legend: dict = dict(prop=dict(size=8), loc='upper right'),
            **kwargs
    ) -> 'plt.Axes':
        """
        Visualize the inferred parameters of the given inference objects.

        The parameter names are taken from the first inference object only, so
        the DFE parametrization needs to be the same for all inference objects.

        :param inferences: List of inference objects.
        :param labels: Unique labels for the DFEs.
        :param confidence_intervals: Whether to plot confidence intervals.
        :param ci_level: Confidence level for confidence intervals.
        :param bootstrap_type: Type of bootstrap to use for confidence intervals.
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :param file: Path to file to save the plot to.
        :param show: Whether to show the plot.
        :param title: Title of the plot.
        :param scale: y-scale of the plot.
        :param ax: Axes to plot on. Only for Python visualization backend.
        :param kwargs_legend: Keyword arguments passed to :meth:`plt.legend`. Only for Python visualization backend.
        :param kwargs: Additional arguments which are ignored.
        :return: Axes of the plot.
        :raises ValueError: If no inference objects are given.
        """
        from .visualization import Visualization

        if len(inferences) == 0:
            raise ValueError('No inference objects given.')

        # get sorted list of parameter names (taken from the first inference)
        param_names = sorted(inferences[0].get_bootstrap_param_names())

        # get errors and values
        errors, values = Inference.get_errors_params_mle(
            bootstrap_type=bootstrap_type,
            ci_level=ci_level,
            confidence_intervals=confidence_intervals,
            inferences=inferences,
            labels=labels,
            param_names=param_names,
            point_estimate=point_estimate
        )

        return Visualization.plot_inferred_parameters(
            values=values,
            errors=errors,
            param_names=param_names,
            file=file,
            show=show,
            title=title,
            labels=labels,
            scale=scale,
            # a legend is only requested when there is more than one label
            legend=len(labels) > 1,
            kwargs_legend=kwargs_legend,
            ax=ax,
        )

    @staticmethod
    def plot_inferred_parameters_boxplot(
            inferences: List['AbstractInference'],
            labels: Sequence,
            file: str = None,
            show: bool = True,
            title: str = 'parameter estimates',
            **kwargs
    ) -> 'plt.Axes':
        """
        Visualize the bootstrapped parameter estimates of the given inference
        objects as boxplots.

        The parameter names are taken from the first inference object only, so
        the DFE parametrization needs to be the same for all inference objects.

        :param inferences: List of inference objects.
        :param labels: Unique labels for the DFEs.
        :param file: Path to file to save the plot to.
        :param show: Whether to show the plot.
        :param title: Title of the plot.
        :param kwargs: Additional keyword arguments. They are accepted but not used by this method.
        :return: Axes of the plot.
        :raises ValueError: If no inference objects are given or the first
            inference object has no bootstraps.
        """
        from .visualization import Visualization

        if len(inferences) == 0:
            raise ValueError('No inference objects given.')

        # get sorted list of parameter names (taken from the first inference)
        param_names = sorted(inferences[0].get_bootstrap_param_names())

        # only the first inference object is checked for bootstraps
        if inferences[0].bootstraps is None:
            raise ValueError('No bootstraps found.')

        # map each label to the bootstrap dataframe of its inference object
        values = {label: inf.bootstraps for label, inf in zip(labels, inferences)}

        return Visualization.plot_inferred_parameters_boxplot(
            values=values,
            param_names=param_names,
            file=file,
            show=show,
            title=title,
        )

    @staticmethod
    def get_errors_params_mle(
            ci_level: float,
            confidence_intervals: bool,
            inferences: List['AbstractInference'],
            labels: Sequence,
            param_names: Sequence,
            bootstrap_type: Literal['percentile', 'bca'],
            point_estimate: Literal['original', 'mean', 'median'] = 'mean',
    ) -> Tuple[Dict[str, Tuple[np.ndarray, np.ndarray] | None], Dict[str, np.ndarray]]:
        """
        Get errors and values for MLE params of inferences.

        :param ci_level: Confidence level for confidence intervals.
        :param confidence_intervals: Whether to compute confidence intervals.
        :param inferences: List of inference objects.
        :param labels: Labels for the inferences.
        :param param_names: Names of the parameters to get errors and values for.
        :param bootstrap_type: Type of bootstrap to use.
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :return: Dictionary of errors and dictionary of center values, both indexed by labels.
            The error entry is ``None`` for inferences for which no errors were computed.
        """
        errors, center = {}, {}

        for label, inf in zip(labels, inferences):
            # fetch the parameters once per inference, then pick them in the requested order
            params = inf.get_bootstrap_params()
            values = [params[name] for name in param_names]

            # errors are only computed if requested and if bootstraps are available
            if confidence_intervals and inf.bootstraps is not None:
                # the third return value is not needed here
                center[label], errors[label], _ = Bootstrap.get_errors(
                    values=values,
                    bs=inf.bootstraps[param_names].to_numpy(),
                    bootstrap_type=bootstrap_type,
                    ci_level=ci_level,
                    point_estimate=point_estimate,
                )
            else:
                center[label] = np.array(values)
                errors[label] = None

        return errors, center

    @staticmethod
    def get_discretized(
            inferences: List['AbstractInference'],
            labels: Sequence,
            intervals: np.ndarray = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean'
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Optional[np.ndarray]]]:
        """
        Get values and errors of discretized DFE.

        :param inferences: List of inference objects.
        :param labels: Labels for the DFEs.
        :param intervals: Array of interval boundaries over ``(-inf, inf)`` yielding ``intervals.shape[0] - 1`` bars.
        :param confidence_intervals: Whether to compute confidence intervals
        :param ci_level: Confidence interval level
        :param bootstrap_type: Type of bootstrap to use
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :return: Dictionary of values and dictionary of errors indexed by labels.
            The error entry is ``None`` for inferences for which no errors were computed.
        """
        values = {}
        errors = {}

        for label, inf in zip(labels, inferences):
            if confidence_intervals and inf.bootstraps is not None:
                # get bootstraps and errors if specified
                values[label], errors[label], _ = Inference.get_stats_discretized(
                    params=inf.get_bootstrap_params(),
                    bootstraps=inf.bootstraps,
                    model=inf.model,
                    ci_level=ci_level,
                    intervals=intervals,
                    bootstrap_type=bootstrap_type,
                    point_estimate=point_estimate
                )
            else:
                # otherwise just get discretized values
                values[label] = Inference.compute_histogram(
                    params=inf.get_bootstrap_params(),
                    model=inf.model,
                    intervals=intervals
                )
                errors[label] = None

        return values, errors

    @staticmethod
    def get_stats_discretized(
            params: dict,
            bootstraps: pd.DataFrame,
            model: Parametrization | str,
            ci_level: float = 0.05,
            intervals: np.ndarray = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean'
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute errors and confidence interval for a discretized DFE.

        :param params: Parameters of the model
        :param bootstraps: Bootstrapped samples
        :param model: DFE parametrization
        :param ci_level: Confidence interval level
        :param intervals: Array of interval boundaries yielding ``intervals.shape[0] - 1`` bins.
        :param bootstrap_type: Type of bootstrap
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :return: Center values, errors around center, and confidence intervals.
        """
        # discretize MLE DFE
        values = Inference.compute_histogram(model, params, intervals)

        # get discretized DFE per bootstrap sample; every dataframe row is
        # converted to a parameter dict
        bs = np.array([
            Inference.compute_histogram(model, dict(row), intervals)
            for _, row in bootstraps.iterrows()
        ])

        return Bootstrap.get_errors(
            values=values,
            bs=bs,
            bootstrap_type=bootstrap_type,
            ci_level=ci_level,
            point_estimate=point_estimate
        )

    @staticmethod
    def compute_histogram(
            model: Parametrization | str,
            params: dict,
            intervals: np.ndarray
    ) -> np.ndarray:
        """
        Discretize the DFE given a DFE parametrization and its parameter values.

        :param model: DFE parametrization
        :param params: Parameters of the model
        :param intervals: Array of interval boundaries yielding ``intervals.shape[0] - 1`` bins.
        :return: Discretized DFE, normalized by its sum.
        """
        # ``model`` is passed through ``_from_string`` before discretizing
        y = _from_string(model)._discretize(params, intervals)

        # return normalized histogram
        return y / y.sum()
class AbstractInference(Serializable, ABC):
    """
    Base class for main Inference and polyDFE wrapper.

    Subclasses have to implement :meth:`get_bootstrap_params` and
    :meth:`get_bootstrap_param_names`.
    """

    def __init__(self, **kwargs):
        """
        Initialize the inference.

        :param kwargs: Keyword arguments (not used by this initializer)
        """
        # child logger named after the concrete class
        self._logger = logger.getChild(self.__class__.__name__)

        # bootstrap samples, ``None`` until set
        self.bootstraps: Optional[pd.DataFrame] = None

        # MLE parameters, ``None`` until set
        self.params_mle: Optional[dict] = None

        # DFE parametrization, ``None`` until set
        self.model: Optional[Parametrization] = None

    @abstractmethod
    def get_bootstrap_params(self) -> Dict[str, float]:
        """
        Get the parameters to be included in the bootstraps.

        :return: Parameters to be included in the bootstraps
        """

    def get_discretized(
            self,
            intervals: np.ndarray = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean'
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get discretized DFE of this inference object.

        Delegates to :meth:`Inference.get_discretized` with this object as the
        only inference.

        :param intervals: Array of interval boundaries over ``(-inf, inf)`` yielding ``intervals.shape[0] - 1`` bins.
        :param confidence_intervals: Whether to return confidence intervals
        :param ci_level: Confidence interval level
        :param bootstrap_type: Type of bootstrap
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :return: Array of values and array of deviations
        """
        # single label under which this object's results are stored
        label = 'all'

        discretized, deviations = Inference.get_discretized(
            inferences=[self],
            labels=[label],
            intervals=intervals,
            confidence_intervals=confidence_intervals,
            ci_level=ci_level,
            bootstrap_type=bootstrap_type,
            point_estimate=point_estimate
        )

        return discretized[label], deviations[label]

    def plot_discretized(
            self,
            file: str = None,
            show: bool = True,
            intervals: np.ndarray = np.array([-np.inf, -100, -10, -1, 0, 1, np.inf]),
            confidence_intervals: bool = True,
            ci_level: float = 0.05,
            bootstrap_type: Literal['percentile', 'bca'] = 'percentile',
            point_estimate: Literal['original', 'mean', 'median'] = 'mean',
            title: str = 'discretized DFE',
            ax: 'plt.Axes' = None,
            kwargs_legend: dict = dict(prop=dict(size=8)),
    ) -> 'plt.Axes':
        """
        Plot discretized DFE of this inference object.

        Delegates to :meth:`Inference.plot_discretized` with this object as the
        only inference.

        :param file: File to save the plot to
        :param show: Whether to show the plot
        :param intervals: Array of interval boundaries over ``(-inf, inf)`` yielding ``intervals.shape[0] - 1`` bars.
        :param confidence_intervals: Whether to plot confidence intervals
        :param ci_level: Confidence interval level
        :param bootstrap_type: Type of bootstrap
        :param point_estimate: Whether to use 'original' MLE values, 'mean' or 'median' of bootstraps as point estimate.
        :param title: Title of the plot
        :param ax: Axes to plot on. Only for Python visualization backend.
        :param kwargs_legend: Keyword arguments passed to :meth:`plt.legend`. Only for Python visualization backend.
        :return: Axes
        """
        axes = Inference.plot_discretized(
            inferences=[self],
            intervals=intervals,
            confidence_intervals=confidence_intervals,
            ci_level=ci_level,
            bootstrap_type=bootstrap_type,
            point_estimate=point_estimate,
            file=file,
            show=show,
            title=title,
            ax=ax,
            kwargs_legend=kwargs_legend
        )

        return axes

    @abstractmethod
    def get_bootstrap_param_names(self) -> List[str]:
        """
        Get the names of the parameters to be included in the bootstraps.

        :return: Names of the parameters to be included in the bootstraps
        """