Source code for fastdfe.optimization

"""
Optimization module.
"""

__author__ = "Janek Sendrowski"
__contact__ = "sendrowski.janek@gmail.com"
__date__ = "2023-02-26"

import copy
import logging
import math
from dataclasses import dataclass
from typing import Callable, List, Dict, Literal, Tuple, Optional, Sequence

import multiprocess as mp
import numpy as np
import pandas as pd
from numpy.linalg import norm
from numpy.random import Generator
from scipy.stats import loguniform, uniform
from tqdm import tqdm

from .likelihood import Likelihood
from .settings import Settings

# get logger
logger = logging.getLogger('fastdfe').getChild('Optimization')


def parallelize(
        func: Callable,
        data: Sequence,
        parallelize: bool = True,
        pbar: bool = None,
        desc: str = None,
        dtype: type = object,
        wrap_array: bool = True
) -> np.ndarray:
    """
    Parallelize given function or execute sequentially.

    :param parallelize: Whether to parallelize
    :param data: Data to iterate over
    :param func: Function to apply to each element of data
    :param pbar: Whether to show a progress bar
    :param desc: Description for progress bar
    :param dtype: Data type of the returned array
    :param wrap_array: Whether to wrap the result in a numpy array
    :return: List of results
    """
    n = len(data)

    if parallelize and n > 1 and Settings.parallelize is not False:
        # parallelize
        iterator = mp.Pool().imap(func, data)
    else:
        # sequentialize
        iterator = map(func, data)

    # whether to show a progress bar
    if pbar is True or (pbar is None and n > 1):
        iterator = tqdm(iterator, total=n, disable=Settings.disable_pbar, desc=desc)

    if wrap_array:
        return np.array(list(iterator), dtype=dtype)

    return list(iterator)


def flatten_dict(d: dict, separator='.', prefix=''):
    """
    Flatten dictionary.

    :param d: The nested dictionary
    :param separator: The separator character to use in the flattened dictionary keys
    :param prefix: The prefix to use in the flattened dictionary keys
    :return: The flattened dictionary
    """
    res = {}

    for key, value in d.items():
        if isinstance(value, dict):
            # recursive call
            res.update(flatten_dict(value, separator, prefix + key + separator))
        else:
            res[prefix + key] = value

    return res


def unflatten_dict(d: dict, separator='.'):
    """
    Unflatten dictionary.

    :param d: The flattened dictionary
    :param separator: The separator character used in the flattened dictionary keys
    :return: The original nested dictionary
    """
    res = {}

    for key, value in d.items():
        subkeys = key.split(separator)
        subdict = res

        # recursively create nested dictionaries for each subkey
        for subkey in subkeys[:-1]:
            subdict = subdict.setdefault(subkey, {})

        # assign value to the final subkey
        subdict[subkeys[-1]] = value if not isinstance(value, dict) else unflatten_dict(value, separator)

    return res


def unpack_params(x: np.ndarray, original: Dict[str, dict | tuple | float]) -> Dict[str, dict | tuple | float]:
    """
    Unpack params from numpy array. This is the inverse of pack_params and is used
    as scipy.optimize.minimize only accepts numpy arrays as parameters.

    :param x: Numpy array
    :param original: Original dictionary
    :return: Unpacked dictionary
    """
    keys = flatten_dict(original).keys()

    return unflatten_dict(dict(zip(keys, x)))


def pack_params(params: Dict[str, dict | tuple | float]) -> np.ndarray:
    """
    Pack params into numpy array. This is used as scipy.optimize.minimize only accepts
    numpy arrays as parameters. This is the inverse of unpack_params.

    :param params: Dictionary to pack
    :return: numpy array
    """
    flattened = flatten_dict(params)

    return np.array(list(flattened.values()))


def filter_dict(d, keys):
    """
    Recursively filter a dictionary by a list of given keys at the deepest level.

    :param d: The dictionary to filter
    :param keys: The list of keys to keep in the filtered dictionary
    :return: The filtered dictionary
    """
    filtered = {}

    for key, value in d.items():
        if isinstance(value, dict):
            filtered_sub = filter_dict(value, keys)
            if filtered_sub:
                filtered[key] = filtered_sub
        else:
            if key in keys:
                filtered[key] = value

    return filtered


def pack_shared(
        params: Dict[str, Dict[str, float]],
        shared: List['SharedParams'],
        shared_values: Dict[str, float]
) -> Dict[str, Dict[str, float]]:
    """
    Pack shared parameters. Here we extract shared parameters from
    type-specific keys and instead create entries with joint-types
    holding the shared parameters.
    Note that we only delete parameters from marginal types but in some
    cases we would have to delete them from compound types as well if
    a more general shared parameter were to be specified later on.
    We instead rely on the user not to do this.

    :param params: Dictionary of parameters indexed by type
    :param shared: List of shared parameters
    :param shared_values: Dictionary of parameters indexed by joint-type and parameter name, i.e. 'type1:type2.p1'
    :return: Packed dictionary
    """
    packed = copy.deepcopy(params)

    # iterator through shared parameter
    for s in shared:

        # remove shared keys
        for t in s.types:
            for p in s.params:
                if p in packed[t]:
                    packed[t].pop(p)

        # create joint type string
        type_string = ':'.join(s.types)

        # create shared type with appropriate parameters
        packed = merge_dicts(packed, {type_string: dict((p, shared_values[type_string + '.' + p]) for p in s.params)})

    return packed


def unpack_shared(params: dict) -> dict:
    """
    Unpack shared parameters. Here we extract shared parameters from joint-type keys
    and add them tp the targeted types.

    :param params: Dictionary of parameters
    :return: Unpacked dictionary
    """
    unpacked = {}

    for t, v in params.items():
        if ':' in t:
            unpacked = merge_dicts(unpacked, dict((s, v) for s in t.split(':')))
        else:
            unpacked[t] = v

    return unpacked


def expand_shared(params: List['SharedParams'], types: List[str], names: List[str]) -> List['SharedParams']:
    """
    Expand 'all' type for shared parameters.

    :param params: List of shared parameters
    :param types: List of types
    :param names: List of parameter names
    :return: Expanded list of shared parameters
    """
    expanded = []

    for x in params:
        # expand 'all' type
        types = types if x.types == 'all' else x.types
        params = names if x.params == 'all' else x.params

        # noinspection PyTypeChecker
        expanded.append(SharedParams(types=types, params=params))

    return expanded


def expand_fixed(
        fixed_params: Dict[str, Dict[str, float]],
        types: List[str]
) -> Dict[str, Dict[str, float]]:
    """
    Expand 'all' type for shared parameters.

    :param fixed_params: Dictionary of fixed parameters indexed by type and parameter
    :param types: List of types
    :return: Expanded dictionary of fixed parameters
    """
    expanded = {}

    # loop through fixed parameters
    for key_type, params in fixed_params.items():
        # expand 'all' type
        key_types = types if key_type == 'all' else [key_type]

        # loop through types
        for t in key_types:
            if t not in expanded:
                expanded[t] = {}

            if isinstance(params, dict):
                for param, value in params.items():
                    expanded[t][param] = value

    return expanded


def collapse_fixed_to_mean(
        expanded_params: Dict[str, Dict[str, float]],
        types: List[str]
) -> Dict[str, Dict[str, float]]:
    """
    Collapse expanded fixed parameters to 'all' type if all types have the same fixed parameter.
    Take the mean of the fixed parameters.

    :param expanded_params: Expanded dictionary of fixed parameters
    :param types: List of types
    :return: Collapsed dictionary of fixed parameters
    """
    out = {"all": dict(expanded_params.get("all", {}))}

    # collect params appearing in any type
    params = set()
    for t in types:
        params |= expanded_params.get(t, {}).keys()

    for p in params:
        vals = [expanded_params[t][p] for t in types if p in expanded_params.get(t, {})]
        if len(vals) == len(types):
            out["all"][p] = float(np.mean(vals))

    return out

def collapse_fixed(
        expanded_params: Dict[str, Dict[str, float]],
        types: List[str]
) -> Dict[str, Dict[str, float]]:
    """
    Collapse expanded fixed parameters to 'all' type if all types have the same fixed parameter.

    :param expanded_params: Expanded dictionary of fixed parameters
    :param types: List of types
    :return: Collapsed dictionary of fixed parameters
    """
    # copy first
    out = {k: dict(v) for k, v in expanded_params.items()}
    out.setdefault("all", {})

    # find params present in *every* type with identical value
    common = None
    for t in types:
        d = expanded_params.get(t, {})
        keys = set(d.keys())
        common = keys if common is None else (common & keys)

    if not common:
        return out

    for p in list(common):
        vals = [expanded_params[t][p] for t in types]
        if all(v == vals[0] for v in vals):
            out["all"][p] = vals[0]
            for t in types:
                out[t].pop(p, None)

    return out



def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dictionaries recursively.

    :param dict1: First dictionary
    :param dict2: Second dictionary
    :return: Merged dictionary
    """
    # make a copy of the first dictionary
    result = dict(dict1)

    # loop through the items in the second dictionary
    for key, value in dict2.items():

        # Check if the key already exists in the result dictionary and both the
        # value in the result and dict2 dictionaries are dictionaries.
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):

            # recursively merge the two dictionaries
            result[key] = merge_dicts(result[key], value)
        else:
            # simply assign the value from dict2 to the result dictionary
            result[key] = value

    return result


def correct_values(
        params: Dict[str, float],
        bounds: Dict[str, Tuple[float, float]],
        scales: Dict[str, Literal['lin', 'log', 'symlog']],
        warn: bool = False,
        threshold: float = 1e-6
) -> Dict[str, float]:
    """
    Correct initial values so that they are within the specified bounds.

    :param bounds: Dictionary of bounds
    :param params: Flattened dictionary of parameters
    :param scales: Dictionary of scales
    :param warn: Whether to warn if values are corrected
    :param threshold: Threshold for the error to trigger a warning
    :return: Corrected dictionary
    """
    # create a copy of params
    corrected = params.copy()

    for key, value in params.items():
        # get base name
        name = key.split('.')[-1]

        # get real bounds
        bound = get_real_bounds(bounds[name], scale=scales[name])

        # correct value if outside bounds
        if value < bound[0]:
            corrected[key] = bound[0]
        elif value > bound[1]:
            corrected[key] = bound[1]

    # differences between the original and corrected dictionaries
    differences = {key: (params[key], corrected[key]) for key in params if params[key] != corrected[key]}

    # warn if there are differences that exceed the threshold
    exceeded_threshold = {}
    for key, (old_val, new_val) in differences.items():

        # calculate relative error
        err = np.abs(new_val - old_val)

        # add if it exceeds the relative error
        if err > threshold:
            exceeded_threshold[key] = f"{old_val} -> {new_val}"

    if exceeded_threshold and warn:
        logger.warning(f'Given initial values outside bounds. Adjusting {exceeded_threshold}.')

    return corrected


def get_real_bounds(bounds: Tuple[float, float], scale: Literal['lin', 'log', 'symlog']) -> Tuple[float, float]:
    """
    Get real bounds from the given bounds.

    :param bounds: Bounds of the parameter
    :param scale: Scale of the parameter
    :return:
    """
    if scale == 'symlog':
        return -bounds[1], bounds[1]

    return bounds


def evaluate_counts(get_counts: dict, params: dict):
    """
    Evaluate counts using the given parameters.
    Here we assign the parameters to the appropriate types
    obtaining the counts for each type.

    :param get_counts: Dictionary of functions to evaluate counts for each type
    :param params: Dictionary of parameters
    :return: Dictionary of counts
    """
    counts = {}

    # unpack shared parameters
    unpacked = unpack_shared(params)

    # evaluate counts for each type
    for key in get_counts.keys():
        counts[key] = get_counts[key](unpacked[key])

    return counts


def to_symlog(x: float, linthresh: float = 1e-5) -> float:
    """
    Convert a value to the symlog scale.

    :param x: The input value on the original scale.
    :param linthresh: The positive value that determines the range within which the
        symlog scale is linear. Must be greater than 0.
    :return: The value on the symlog scale.
    """
    sign = np.sign(x)
    abs_x = np.abs(x)
    log_x = np.log10(abs_x + linthresh) - np.log10(linthresh)

    return sign * (abs_x / linthresh if abs_x <= linthresh else log_x)


def from_symlog(y: float, linthresh: float = 1e-5) -> float:
    """
    Convert a value from the symlog scale back to the original scale.

    :param y: The input value on the symlog scale.
    :param linthresh: The positive value that determines the range within which the
        symlog scale is linear. Must be greater than 0.
    :return: The value on the original scale.
    """
    sign = np.sign(y)
    abs_y = np.abs(y)
    exp_y = np.power(10, abs_y + np.log10(linthresh)) - linthresh

    return sign * (abs_y * linthresh if abs_y <= 1 else exp_y)


def scale_bound(bounds: Tuple[float, float], scale: Literal['lin', 'log', 'symlog']):
    """
    Convert a bound to the specified scale. For symlog scale we assume the symmetric bounds,
    so that the upper bound denotes the boundaries and the lower bound the linear threshold.

    :param bounds: The bound to convert
    :param scale: The scale to convert to
    :return: The converted bound
    :raises ValueError: if the scale is unknown
    """
    if scale == 'lin':
        return bounds

    if scale == 'log':
        if bounds[1] < 0:
            return -np.log10(-bounds[0]), -np.log10(-bounds[1])

        if bounds[0] > 0:
            return np.log10(bounds[0]), np.log10(bounds[1])

        raise ValueError('Bounds must not span zero for log scale.')

    if scale == 'symlog':

        if bounds[0] <= 0 or bounds[1] <= 0:
            raise ValueError('Both bounds must be positive for symlog scale.')

        return to_symlog(-bounds[1], linthresh=bounds[0]), to_symlog(bounds[1], linthresh=bounds[0])

    raise ValueError(f'Unknown scale {scale}.')


def unscale_bound(
        scaled_bounds: Tuple[float, float],
        scale: Literal['lin', 'log', 'symlog'],
        linthresh: float = 1e-5
) -> Tuple[float, float]:
    """
    Convert a bound from the specified scale back to the original scale. For symlog scale,
    we assume symmetric bounds, so that the upper bound denotes the boundaries and
    the lower bound the linear threshold, i.e. ``bounds = (-bounds[1], bounds[1])`` and ``linthresh = bounds[0]``.
    Note that we cannot reliably recover negative bounds that were log scaled.

    :param linthresh:
    :param scaled_bounds: The bound to convert
    :param scale: The scale to convert from
    :return: The converted bound
    :raises ValueError: if the scale is unknown
    """
    if scale == 'lin':
        return scaled_bounds

    if scale == 'log':
        return np.power(10, scaled_bounds[0]), np.power(10, scaled_bounds[1])

    if scale == 'symlog':
        upper_bound = from_symlog(scaled_bounds[1], linthresh=linthresh)

        return linthresh, upper_bound

    raise ValueError(f'Unknown scale {scale}.')


def scale_bounds(
        bounds: Dict[str, Tuple[float, float]],
        scales: Dict[str, Literal['lin', 'log', 'symlog']]
) -> Dict[str, Tuple[float, float]]:
    """
    Convert bounds to the specified scale. For symlog scale we assume the symmetric bounds,
    so that the upper bound denotes the boundaries and the lower bound the linear threshold.

    :param bounds: Flattened dictionary of bounds to convert index by type and parameter
    :param scales: Dictionary of scales indexed by parameter
    :return: The converted bounds
    :raises ValueError: if the scale is unknown
    """
    scaled_bounds = {}

    for key, value in bounds.items():
        scaled_bounds[key] = scale_bound(value, scale=scales[get_basename(key)])

    return scaled_bounds


def scale_value(value: float, bounds: Tuple[float, float], scale: Literal['lin', 'log', 'symlog']) -> float:
    """
    Convert a value to the specified scale. For symlog scale, the untransformed bounds are needed,
    so that the upper bound denotes the boundaries and the lower bound the linear threshold.

    :param value: The value to convert.
    :param bounds: The untransformed bounds for the symlog scale.
    :param scale: The scale to convert to.
    :return: The converted value.
    :raises ValueError: if the scale is unknown.
    """
    if scale == 'lin':
        return value

    if scale == 'log':

        if value < 0:
            return -np.log10(-value)

        return np.log10(value)

    if scale == 'symlog':
        return to_symlog(value, linthresh=bounds[0])

    raise ValueError(f'Unknown scale {scale}.')


def unscale_value(scaled_value: float, bounds: Tuple[float, float], scale: Literal['lin', 'log', 'symlog']) -> float:
    """
    Convert a value from the specified scale back to the original scale. For symlog scale,
    the untransformed bounds are needed, so that the upper bound denotes the boundaries
    and the lower bound the linear threshold.

    :param scaled_value: The value to convert.
    :param bounds: The untransformed bounds for the symlog scale.
    :param scale: The scale to convert from.
    :return: The converted value.
    :raises ValueError: if the scale is unknown.
    """
    if scale == 'lin':
        return scaled_value

    if scale == 'log':

        if bounds[1] < 0:
            return -np.power(10, -scaled_value)

        return np.power(10, scaled_value)

    if scale == 'symlog':
        return from_symlog(scaled_value, linthresh=bounds[0])

    raise ValueError(f'Unknown scale {scale}.')


def perturb_value(value: float, bounds: Tuple[float, float], rng: np.random.Generator) -> float:
    """
    Perturb a value within the given bounds using a normal distribution with mean at the value
    and standard deviation equal to the value.

    :param value: The value to perturb.
    :param bounds: The bounds to perturb within.
    :param rng: The random number generator to use.
    :return: The perturbed value.
    """
    std = 2 * abs(value) if value != 0 else 0.1 * abs(bounds[1] - bounds[0])
    perturbed = rng.normal(loc=value, scale=std)

    # ensure within bounds
    perturbed = max(bounds[0], min(bounds[1], perturbed))

    return perturbed


def scale_values(
        params: Dict[str, Dict[str, float]],
        bounds: Dict[str, Tuple[float, float]],
        scales: Dict[str, Literal['lin', 'log', 'symlog']]
) -> Dict[str, Dict[str, float]]:
    """
    Scale values according to the given scales.

    :param params: Nested dictionary of parameters indexed by type and parameter
    :param scales: Dictionary of scales indexed by parameter name
    :param bounds: Dictionary of bounds indexed by parameter name
    :return: Nested dictionary of scaled parameters indexed by type and parameter
    """
    scaled = {}

    for key, value in flatten_dict(params).items():
        # scale value
        scaled[key] = scale_value(value, bounds[get_basename(key)], scales[get_basename(key)])

    return unflatten_dict(scaled)


def unscale_values(
        params: Dict[str, Dict[str, float]],
        bounds: Dict[str, Tuple[float, float]],
        scales: Dict[str, Literal['lin', 'log', 'symlog']]
) -> Dict[str, Dict[str, float]]:
    """
    Unscale values according to the given scales.

    :param params: Nested dictionary of parameters indexed by type and parameter
    :param scales: Dictionary of scales indexed by parameter name
    :param bounds: Dictionary of scales indexed by parameter name
    :return: Nested dictionary of unscaled parameters indexed by type and parameter
    """
    unscaled = {}

    for key, value in flatten_dict(params).items():
        # unscale value
        unscaled[key] = unscale_value(value, bounds[get_basename(key)], scales[get_basename(key)])

    return unflatten_dict(unscaled)


def get_basename(name: str) -> str:
    """
    Get the basename of parameter string, i.e. type.param -> param.

    :param name: The string to get the basename from.
    :return: The basename.
    """
    return name.split('.')[-1]


def check_bounds(
        bounds: Dict[str, Tuple[float, float]],
        params: Dict[str, float],
        fixed_params: Dict[str, float] = {},
        percentile: float = 1,
        scale: Literal['lin', 'log'] = 'lin'
) -> Tuple[Dict[str, Tuple[float, float, float]], Dict[str, Tuple[float, float, float]]]:
    """
    Issue warnings if the passed parameters are close to the specified bounds.

    :param bounds: The bounds to check against.
    :param params: The parameters to check.
    :param fixed_params: The fixed parameters.
    :param percentile: The percentile threshold to consider a parameter close to the bounds.
    :param scale: Scale type: 'lin' for linear and 'log' for logarithmic.
    :return: Tuple of dictionaries of parameters close to the lower and upper bounds, i.e. (lower, value, upper).
    """
    near_lower = {}
    near_upper = {}

    def transform(value: float, to_scale: Literal['lin', 'log']) -> float:
        """
        Transform a value to the specified scale.

        :param value: The value to transform.
        :param to_scale: The scale to transform to.
        :return: The transformed value.
        """
        if to_scale == 'log':
            return math.log(value) if value > 0 else -float('inf')

        return value

    for key, value in params.items():
        # get base name
        name = key.split('.')[-1]

        # get bounds
        lower, upper = bounds[name]

        # transform values
        _lower = transform(lower, scale)
        _upper = transform(upper, scale)
        _value = transform(value, scale)

        if key not in fixed_params:
            if _lower is not None and (_value - _lower) / (_upper - _lower) <= percentile / 100:
                near_lower[key] = (lower, value, upper)

            if _upper is not None and (_upper - _value) / (_upper - _lower) <= percentile / 100:
                near_upper[key] = (lower, value, upper)

    return near_lower, near_upper


[docs] @dataclass class SharedParams: """ Class specifying the sharing of params among types. ``all`` means all available types or params. Example usage: :: import fastdfe as fd # neutral SFS for two types sfs_neut = fd.Spectra(dict( pendula=[177130, 997, 441, 228, 156, 117, 114, 83, 105, 109, 652], pubescens=[172528, 3612, 1359, 790, 584, 427, 325, 234, 166, 76, 31] )) # selected SFS for two types sfs_sel = fd.Spectra(dict( pendula=[797939, 1329, 499, 265, 162, 104, 117, 90, 94, 119, 794], pubescens=[791106, 5326, 1741, 1005, 756, 546, 416, 294, 177, 104, 41] )) # create inference object inf = fd.JointInference( sfs_neut=sfs_neut, sfs_sel=sfs_sel, shared_params=[fd.SharedParams(types=["pendula", "pubescens"], params=["eps", "S_d"])], do_bootstrap=True ) # run inference inf.run() """ #: The params to share params: List[str] | Literal['all'] = 'all' #: The types to share types: List[str] | Literal['all'] = 'all'
[docs] @dataclass class Covariate: """ Class defining a covariate which induces a relationship with one or many parameters. The relationship is defined by a callback function which modifies the parameters. The default callback introduces a linear relationship. Below an example of introducing linear covariates for ``S_d``, the for mean strength of deleterious mutations (cf. :class:`~fastdfe.parametrization.GammaExpParametrization`). Each of the three types is associated with one covariate. This we pass to :class:`~fastdfe.joint_inference.JointInference` together with the stratified spectra: :: import fastdfe as fd cov = fd.Covariate( param='S_d', values=dict(type1=5, type2=3, type3=1) ) """ #: The parameter to modify param: str #: The values of the covariate for each type values: Dict[str, float] #: The callback function to modify the parameters callback: Optional[Callable] = None #: The bounds of the covariate parameter to be estimated bounds: tuple = (1e-4, 1e4) #: The initial value of the covariate x0: float = 0 #: The scale of the bounds. See :func:`scale_value` for details bounds_scale: Literal['lin', 'log', 'symlog'] = 'symlog' def __post_init__(self): """ Cast bounds to tuple and check if an inverse_callback is provided when a custom callback is specified. """ self.bounds = tuple(self.bounds)
[docs] def apply(self, covariate: float, type: str, params: Dict[str, float]) -> Dict[str, float]: """ Apply the custom or default callback to modify the given parameters. :param covariate: The value of the covariate. :param type: The type of the relationship. :param params: The input parameters. :return: Modified parameters. """ # Use custom callback if given else default callback callback = self.apply_default if self.callback is None else self.callback return callback(covariate=covariate, type=type, params=params)
[docs] def apply_default(self, covariate: float, type: str, params: Dict[str, float]) -> Dict[str, float]: """ Modify the given parameters introducing a linear relationship with the given covariate. :param covariate: The value of the covariate. :param type: The type of the relationship. :param params: The input parameters. :return: Modified parameters. """ # create a copy of input parameters modified = params.copy() # introduce linear relationship if self.param in params: modified[self.param] += covariate * self.values[type] return modified
@staticmethod def _apply(covariates: Dict[str, 'Covariate'], params: dict, type: str) -> dict: """ Apply given covariates to given parameters. :param covariates: Dictionary of covariates to add :param params: Dict of parameters :param type: SFS type :return: Dict of parameters with covariates added """ for k, cov in covariates.items(): params = cov.apply( covariate=params[k], type=type, params=params ) return params
class Optimization: """ Class for optimizing the DFE. """ #: Optimization method to use. Use class property for downward compatibility method_mle = 'L-BFGS-B' def __init__( self, bounds: Dict[str, Tuple[float, float]], param_names: List[str], loss_type: Literal['likelihood', 'L2'] = 'likelihood', opts_mle: dict = {}, method_mle: str = 'L-BFGS-B', parallelize: bool = True, fixed_params: Dict[str, Dict[str, float]] = {}, scales: Dict[str, Literal['lin', 'log', 'symlog']] = {}, seed: int = None ): """ Create object. :param parallelize: Whether to parallelize the optimization :param bounds: Dictionary of bounds :param opts_mle: Dictionary of options for the optimizer :param method_mle: Optimization method to use :param loss_type: Type of loss function to use :param fixed_params: Dictionary of fixed parameters :param scales: Dictionary of scales :param param_names: List of parameter names """ #: Parameter bounds self.bounds: Dict[str, Tuple[float, float]] = bounds #: Parameter scales to use self.scales: Dict[str, Literal['lin', 'log', 'symlog']] = scales #: additional options for the optimizer self.opts_mle: dict = opts_mle #: Optimization method to use self.method_mle: str = method_mle #: Type of loss function to use self.loss_type: str = loss_type #: Fixed parameters self.fixed_params: Dict[str, Dict[str, float]] = flatten_dict(fixed_params) # check if fixed parameters are within the specified bounds if correct_values(self.fixed_params, self.bounds, warn=False, scales=scales) != self.fixed_params: raise ValueError('Fixed parameters are outside the specified bounds. ' f'Fixed params: {self.fixed_params}, bounds: {self.bounds}.') #: Parameter names self.param_names: List[str] = param_names #: Whether to parallelize the optimization self.parallelize: bool = parallelize #: Initial values self.x0: Optional[dict] = None #: Number of runs self.n_runs: Optional[int] = None #: DataFrame holding information about all optimization runs self.runs: Optional[pd.DataFrame] = None #: Random generator instance self.rng = np.random.default_rng(seed=seed) def run( self, get_counts: Dict[str, Callable], x0: Dict[str, Dict[str, float]] = {}, scales: Dict[str, Literal['lin', 'log', 'symlog']] = {}, bounds: Dict[str, Tuple[float, float]] = {}, n_runs: int = 1, debug_iterations: bool = True, print_info: bool = True, opts_mle: dict = None, pbar: bool = None, desc: str = 'Inferring DFE', ) -> Tuple['scipy.optimize.OptimizeResult', dict]: """ Perform the optimization procedure. :param scales: Scales of the parameters :param bounds: Bounds of the parameters :param n_runs: Number of independent optimization runs out of which the best one is chosen. The first run will use the initial values if specified. Consider increasing this number if the optimization does not produce good results. :param x0: Dictionary of initial values in the form ``{type: {param: value}}`` :param get_counts: Dictionary of functions to evaluate counts for each type :param debug_iterations: Whether to print debug messages for each iteration :param opts_mle: Dictionary of options for the optimizer :param print_info: Whether to print information about the bounds :param pbar: Whether to show a progress bar :param desc: Description for the progress bar :return: The optimization result and the likelihoods """ from scipy.optimize import minimize, OptimizeResult # number of optimization runs self.n_runs = n_runs # store the scales of the parameters if scales: self.scales = scales # store the bounds of the parameters if bounds: self.bounds = bounds # store the options for the optimizer if opts_mle: self.opts_mle = opts_mle # filter out unneeded values # this also holds the fixed parameters self.x0 = filter_dict(x0, self.param_names) # flatten initial values flattened = flatten_dict(self.x0) # determine parameter names of parameters to be optimized optimized_param_names = list(set(flattened) - set(self.fixed_params)) # issue debug messages logger.debug(f'Performing optimization on {len(flattened)} parameters: {list(flattened.keys())}.') logger.debug(f'Using initial values: {flattened}.') if print_info: logger.info(f"Optimizing {len(optimized_param_names)} parameters: [{', '.join(optimized_param_names)}].") # issue warning when the number of parameters to be optimized is large if len(optimized_param_names) > 10: logger.warning(f'A large number of parameters is optimized jointly. ' f'Please be aware that this makes it harder to find a good optimum.') # correct initial values to be within bounds self.x0 = unflatten_dict(correct_values(flattened, self.bounds, warn=True, scales=self.scales)) # determine parameter bounds bounds = self.get_bounds(flatten_dict(self.x0)) def optimize(x0: Dict[str, Dict[str, float]]) -> OptimizeResult: """ Perform numerical minimization. :param x0: Dictionary of initial values in the form ``{type: {param: value}}`` :return: Optimization result """ logger.debug(f"Initial parameters: {x0}.") return minimize( fun=self.get_loss_function( get_counts=get_counts, print_debug=debug_iterations ), x0=pack_params(self.scale_values(x0)), method=self.method_mle, bounds=pack_params(scale_bounds(bounds, self.scales)), options=self.opts_mle ) # initial parameters for the samples initial_params = [self.x0] + [self.sample_x0(self.x0) for _ in range(int(self.n_runs) - 1)] # parallelize MLE for different initializations results = parallelize(optimize, initial_params, self.parallelize, pbar=pbar, desc=desc) # build a pandas DataFrame of all runs records = [] for i, res in enumerate(results): params = unpack_params(res.x, self.x0) params = unscale_values(params, self.bounds, self.scales) flat = flatten_dict(params) flat['likelihood'] = -res.fun flat['success'] = res.success flat['result'] = str(str) flat['x0'] = str(flatten_dict(initial_params[i])) records.append(flat) # store runs as DataFrame self.runs = pd.DataFrame(records) # get result with the lowest likelihood result = results[np.argmax(self.runs.likelihood)] # unpack MLE params array into a dictionary params_mle = unpack_params(result.x, self.x0) # unscale parameters params_mle = unscale_values(params_mle, self.bounds, self.scales) # check if the MLE reached one of the bounds if print_info: self.check_bounds(flatten_dict(params_mle)) return result, params_mle def scale_values(self, values: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]: """ Scale the values of the parameters. :param values: Dictionary of initial values in the form ``{type: {param: value}}`` :return: Dictionary of scaled initial values """ return scale_values(values, self.bounds, self.scales) def get_loss_function( self, get_counts: Dict[str, Callable], print_debug: bool = True ) -> Callable: """ Get the loss function. :param get_counts: Dictionary of functions to evaluate counts for each type :param print_debug: Whether to print debug messages :return: The loss function """ def loss(x: np.ndarray) -> float: """ The loss function. :param x: Parameters :return: The loss """ # unpack parameters into dictionary using the keys of self.x0 params = unpack_params(x, self.x0) # unscale parameters params = unscale_values(params, self.bounds, self.scales) # Model SFS from parameters. # Here the order of types does not matter. # We only collect the counts for types that are # given in get_counts. This makes it possible to # avoid specifying type 'all' which is of no use # in joint inference. counts_dict = evaluate_counts(get_counts, params) # flatten and convert to array counts = np.array(list(counts_dict.values())) # reshape and merge counts_modelled, counts_observed = np.stack(counts, axis=1).reshape(2, -1) # use independent Poisson likelihoods LL = Likelihood.log_poisson(mu=counts_modelled, k=counts_observed) # combine likelihoods ll = np.sum(LL) # compute L2 norm L2 = norm(counts_modelled - counts_observed, 2) # information on iteration iter_info = flatten_dict(params) | dict(likelihood=ll, L2=L2) # log likelihood if print_debug: # check likelihood if np.isnan(ll): raise ValueError('Oh boy, likelihood is nan. This is no good...') # log variables logger.debug(iter_info) # return appropriate loss return dict(L2=L2, likelihood=-ll)[self.loss_type] return loss def sample_x0(self, example: dict, random_state: int | Generator = None) -> Dict[str, dict]: """ Sample initial values. :param example: An example dictionary for generating the initial values :param random_state: Random state or seed :return: A dictionary of initial values """ if random_state is None: random_state = self.rng sample = {} for key, value in example.items(): if isinstance(value, dict): sample[key] = self.sample_x0(value, random_state) elif key in self.bounds and key in self.scales: sample[key] = self.sample_value(self.bounds[key], self.scales[key], random_state) return sample @staticmethod def sample_value( bounds: Tuple[float, float], scale: Literal['lin', 'log', 'symlog'], random_state: int | Generator = None ) -> float: """ Sample a value between given bounds using the given scaling. This function works for positive, negative, and mixed bounds. Note that when ``scale == 'symlog'``, ``bounds[0]`` defines the linear threshold and the actual bounds are ``(-bounds[1], bounds[1])``. :param bounds: Tuple of lower and upper bounds :param scale: Scaling of the parameter. :param random_state: Random state or seed :return: Sampled value """ def flip(bounds: Tuple[float, float]) -> Tuple[float, float]: """ Flip the bounds. :param bounds: Tuple of lower and upper bounds :return: Flipped bounds """ return -bounds[1], -bounds[0] def symlog_rvs(lower: float, upper: float, random_state: int | Generator = None) -> float: """ Sample from a symmetric log-uniform distribution. :param lower: Lower bound which is the linear threshold :param upper: Upper bound so that the actual bounds are (-upper, upper) :param random_state: Random state :return: Sampled value """ val = loguniform.rvs(lower, upper, random_state=random_state) # flip sign with 50% probability return val if uniform.rvs() < 0.5 else -val # dictionary of scaling functions scaling_functions = { 'lin': uniform.rvs, 'log': loguniform.rvs, 'symlog': symlog_rvs } # raise an error if the scale is not valid if scale not in scaling_functions: raise ValueError(f"Scale must be one of: {', '.join(scaling_functions.keys())}") # raise an error if bounds span 0 and scale is 'log' if bounds[0] < 0 < bounds[1] and scale == 'log': raise ValueError(f"Log scale not possible for bounds that span 0.") # raise an error if bounds are negative and scale is 'symlog' if bounds[0] < 0 and scale == 'symlog': raise ValueError(f"Symlog scale not possible for negative bounds.") # flip bounds if they are negative flipped = bounds[0] < 0 if flipped: bounds = flip(bounds) # sample a value using the appropriate scaling function sample = scaling_functions[scale](bounds[0], bounds[1] - bounds[0], random_state=random_state) # return the sampled value, flipping back if necessary return -sample if flipped else sample def check_bounds(self, params: Dict[str, float], percentile: float = 1) -> None: """ Check if the given parameters are within the bounds. :param params: Parameters :param percentile: Percentile of the bounds to check :return: Whether the parameters are within the bounds """ # we scale the bounds to obtain more sensible warnings bounds = scale_bounds(self.bounds, self.scales) params_scaled = flatten_dict(self.scale_values(unflatten_dict(params))) # get parameters close to the bounds near_lower, near_upper = check_bounds( params=params_scaled, bounds=bounds, fixed_params=self.fixed_params, percentile=percentile, scale='lin' ) if len(near_lower | near_upper) > 0: def get_values(keys: List[str]) -> Dict[str, Tuple[str, str, str]]: """ Unscale the parameters. :param keys: List of parameter names :return: Unscaled parameters """ unscaled = {} for key in keys: unscaled[key] = ( "{:g}".format(self.bounds[get_basename(key)][0]), "{:.8g}".format(params[key]), "{:g}".format(self.bounds[get_basename(key)][1]) ) return unscaled # string representation of parameters near_lower_unscaled = str(get_values(list(near_lower.keys()))).replace('\'', '') near_upper_unscaled = str(get_values(list(near_upper.keys()))).replace('\'', '') # issue warning logger.warning( f'The MLE estimate is close to the upper bound ' f'for {near_upper_unscaled} and lower bound ' f'for {near_lower_unscaled} [(lower, value, upper)], but ' f'this might be nothing to worry about.' ) def get_bounds(self, x0: Dict[str, float]) -> Dict[str, Tuple[float, float]]: """ Get a nested dictionary of bounds the same structure as the given initial values. :param x0: Flattened dictionary of initial values :return: A dictionary of initial values """ bounds = {} for key, value in x0.items(): # check if the parameter is fixed if key in self.fixed_params: bounds[key] = (self.fixed_params[key], self.fixed_params[key]) else: bounds[key] = self.bounds[key.split('.')[-1]] return bounds def set_fixed_params(self, fixed_params: Dict[str, Dict[str, float]]): """ Set fixed parameters. We flatten the dictionary to make it easier to work with. :param fixed_params: Dictionary of fixed parameters """ self.fixed_params = flatten_dict(fixed_params) @staticmethod def perturb_params( params: Dict[str, Dict[str, float]], bounds: Dict[str, Tuple[float, float]], seed: Optional[int] = None, ) -> Dict[str, Dict[str, float]]: """ Perturb values within the given bounds using a normal distribution with mean at the value and standard deviation equal to the value. :param params: Nested dictionary of parameters indexed by type and parameter :param bounds: Dictionary of bounds indexed by parameter name :param seed: Seed for the random number generator :return: Nested dictionary of perturbed parameters indexed by type and parameter """ rng = np.random.default_rng(seed) perturbed = {} for key, value in flatten_dict(params).items(): # perturb value perturbed[key] = perturb_value(value, bounds[get_basename(key)], rng) return unflatten_dict(perturbed)