"""
VCF annotations and an annotator to apply them.
"""
__author__ = "Janek Sendrowski"
__contact__ = "sendrowski.janek@gmail.com"
__date__ = "2023-05-09"
import itertools
import logging
import re
import subprocess
import tempfile
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from enum import Enum
from functools import cached_property
from io import StringIO
from itertools import product
from typing import List, Optional, Dict, Tuple, Callable, Literal, Iterable, cast, Any, Generator, Union
import Bio.Data.CodonTable
import jsonpickle
import numpy as np
import pandas as pd
from Bio import Phylo
from Bio.Phylo.BaseTree import Clade, Tree
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from scipy.stats import hypergeom
from tqdm import tqdm
from .io_handlers import DummyVariant, MultiHandler, FASTAHandler
from .io_handlers import GFFHandler, get_major_base, get_called_bases
from .optimization import parallelize as parallelize_func, check_bounds
from .settings import Settings
from .spectrum import Spectra
# package-wide logger from which the class-specific loggers are derived
logger = logging.getLogger('fastdfe')

# the four nucleotides; their order defines the base indices below
bases = np.array(list('ACGT'))

# index of every nucleotide within ``bases``
base_indices = dict(A=0, C=1, G=2, T=3)

# mapping from codon to amino acid (standard DNA table)
codon_table = Bio.Data.CodonTable.standard_dna_table.forward_table

# codons that terminate translation
stop_codons = Bio.Data.CodonTable.standard_dna_table.stop_codons

# codons that initiate translation
start_codons = ['ATG']

# Register the stop codons in the codon table under a dedicated symbol.
# NOTE: ``codon_table`` is the very object held by Biopython (no copy is made),
# so these entries are added to Biopython's table in place.
for c in stop_codons:
    codon_table[c] = 'Σ'

# Maps the number of alternative bases at a codon position that leave the
# amino acid unchanged to the degeneracy of that position.
# The third position of the isoleucine codon is counted as 2-fold degenerate.
# This is the only site that would normally have 3-fold degeneracy
# (https://en.wikipedia.org/wiki/Codon_degeneracy)
unique_to_degeneracy = {0: 0, 1: 2, 2: 2, 3: 4}
class Annotation(ABC):
    """
    Abstract base class shared by all annotations.

    Concrete annotations implement :meth:`annotate_site`. The hooks :meth:`_setup`,
    :meth:`_rewind` and :meth:`_teardown` store the handler, reset the site counter
    and log the number of annotated sites, respectively.
    """

    def __init__(self):
        """
        Initialize the state common to all annotations.
        """
        #: Logger named after the concrete annotation class.
        self._logger = logger.getChild(type(self).__name__)

        #: The handler providing context, assigned in :meth:`_setup`.
        self._handler: MultiHandler | None = None

        #: The number of annotated sites.
        self.n_annotated: int = 0

    def _setup(self, handler: MultiHandler):
        """
        Store the handler that provides the context for this annotation.
        This should be called before the annotation starts.

        :param handler: The handler.
        """
        self._handler = handler

    def _rewind(self):
        """
        Reset the counter of annotated sites.
        """
        self.n_annotated = 0

    def _teardown(self):
        """
        Log how many sites were annotated. Called after all sites have been annotated.
        """
        self._logger.info(f'Annotated {self.n_annotated} sites.')

    @abstractmethod
    def annotate_site(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Annotate a single site. To be implemented by subclasses.

        :param variant: The variant to annotate.
        """

    @staticmethod
    def count_target_sites(file: str, remove_overlaps: bool = False, contigs: List[str] = None) -> Dict[str, int]:
        """
        Count the number of target sites in a GFF file.

        :param file: The path to the GFF file, possibly gzipped or a URL.
        :param remove_overlaps: Whether to remove overlapping target sites.
        :param contigs: The contigs to count the target sites for.
        :return: The number of target sites per chromosome/contig.
        """
        gff = GFFHandler(file)

        return gff._count_target_sites(remove_overlaps=remove_overlaps, contigs=contigs)
class DegeneracyAnnotation(Annotation):
    """
    Degeneracy annotation. We annotate the degeneracy by looking at each codon for coding variants.
    This also annotates mono-allelic sites.

    This annotation adds the info fields ``Degeneracy`` and ``Degeneracy_Info``, which hold the degeneracy
    of a site (0, 2, 4) and extra information about the degeneracy, respectively. To be used with
    :class:`~fastdfe.parser.DegeneracyStratification`.

    For this annotation to work, we require a FASTA and GFF file (passed to :class:`~fastdfe.parser.Parser` or
    :class:`~fastdfe.annotation.Annotator`).

    Example usage:

    ::

        import fastdfe as fd

        ann = fd.Annotator(
            vcf="http://ftp.1000genomes.ebi.ac.uk/vol1/ftp/data_collections/"
                "1000_genomes_project/release/20181203_biallelic_SNV/"
                "ALL.chr21.shapeit2_integrated_v1a.GRCh38.20181129.phased.vcf.gz",
            fasta="http://ftp.ensembl.org/pub/release-109/fasta/homo_sapiens/"
                  "dna/Homo_sapiens.GRCh38.dna.chromosome.21.fa.gz",
            gff="http://ftp.ensembl.org/pub/release-109/gff3/homo_sapiens/"
                "Homo_sapiens.GRCh38.109.chromosome.21.gff3.gz",
            output='sapiens.chr21.degeneracy.vcf.gz',
            annotations=[fd.DegeneracyAnnotation()],
            aliases=dict(chr21=['21'])
        )

        ann.annotate()

    """

    #: Sentinel start/end position of mocked coding sequences (used when no real one is available).
    _pos_mock: float = 1e100

    def __init__(self):
        """
        Create a new annotation instance.
        """
        Annotation.__init__(self)

        #: The current coding sequence or the closest coding sequence downstream.
        self._cd: Optional[pd.Series] = None

        #: The coding sequence following the current coding sequence.
        self._cd_next: Optional[pd.Series] = None

        #: The coding sequence preceding the current coding sequence.
        self._cd_prev: Optional[pd.Series] = None

        #: The current contig.
        self._contig: Optional[SeqRecord] = None

        #: The variants whose reference allele did not match the reference genome.
        self.mismatches: List['cyvcf2.Variant'] = []

        #: The number of variants that were skipped (not in a coding sequence or not a single site).
        self.n_skipped: int = 0

        #: The variants for which the codon could not be determined.
        self.errors: List['cyvcf2.Variant'] = []

    def _setup(self, handler: MultiHandler):
        """
        Provide context to the annotation and register the info fields in the VCF header.

        :param handler: The handler.
        """
        # require FASTA and GFF files
        handler._require_fasta(self.__class__.__name__)
        handler._require_gff(self.__class__.__name__)

        # call super
        super()._setup(handler)

        # touch the cached properties to make for a nicer logging experience
        # noinspection PyStatementEffect
        self._handler._cds

        # noinspection PyStatementEffect
        self._handler._ref

        handler._reader.add_info_to_header({
            'ID': 'Degeneracy',
            'Number': '.',
            'Type': 'Integer',
            'Description': 'n-fold degeneracy'
        })

        # the value written in ``annotate_site`` is a comma-separated string
        # ("<codon position>,<strand>,<codon>"), hence the type has to be String
        handler._reader.add_info_to_header({
            'ID': 'Degeneracy_Info',
            'Number': '.',
            'Type': 'String',
            'Description': 'Additional information about degeneracy annotation'
        })

    def _rewind(self):
        """
        Rewind the annotation, dropping the cached coding sequences and contig.
        """
        Annotation._rewind(self)

        self._cd = None
        self._cd_next = None
        self._cd_prev = None
        self._contig = None

    def _parse_codon_forward(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Parse the codon in forward direction.

        :param variant: The variant.
        :return: Codon, codon positions, codon start position, position within codon, and relative position.
        :raises IndexError: If the codon extends beyond the current coding sequence and no
            neighbouring coding sequence is available on that side.
        """
        # position relative to start of coding sequence
        pos_rel = variant.POS - (self._cd.start + int(self._cd.phase))

        # position relative to codon
        pos_codon = pos_rel % 3

        # inclusive codon start, 1-based
        codon_start = variant.POS - pos_codon

        # the codon positions
        codon_pos = [codon_start, codon_start + 1, codon_start + 2]

        # Only the previous coding sequence is needed to complete a codon that begins
        # before the current one; whether a subsequent coding sequence exists is irrelevant here.
        if self._cd_prev is None and codon_pos[0] < self._cd.start:
            raise IndexError(f'Codon at site {variant.CHROM}:{variant.POS} overlaps with '
                             f'start position of current CDS and no previous CDS was given.')

        # Use final positions from previous coding sequence if current codon
        # starts before start position of current coding sequence
        if codon_pos[1] == self._cd.start:
            codon_pos[0] = self._cd_prev.end if self._cd_prev.strand == '+' else self._cd_prev.start
        elif codon_pos[2] == self._cd.start:
            codon_pos[1] = self._cd_prev.end if self._cd_prev.strand == '+' else self._cd_prev.start
            codon_pos[0] = self._cd_prev.end - 1 if self._cd_prev.strand == '+' else self._cd_prev.start + 1

        if (self._cd_next is None or self._cd_next.start == self._pos_mock) and codon_pos[2] > self._cd.end:
            raise IndexError(f'Codon at site {variant.CHROM}:{variant.POS} overlaps with '
                             f'end position of current CDS and no subsequent CDS was given.')

        # use initial positions from subsequent coding sequence if current codon
        # ends after end position of current coding sequence
        if codon_pos[2] == self._cd.end + 1:
            codon_pos[2] = self._cd_next.start if self._cd_next.strand == '+' else self._cd_next.end
        elif codon_pos[1] == self._cd.end + 1:
            codon_pos[1] = self._cd_next.start if self._cd_next.strand == '+' else self._cd_next.end
            codon_pos[2] = self._cd_next.start + 1 if self._cd_next.strand == '+' else self._cd_next.end - 1

        # seq uses 0-based positions
        codon = ''.join([str(self._contig[int(pos - 1)]) for pos in codon_pos]).upper()

        return codon, codon_pos, codon_start, pos_codon, pos_rel

    def _parse_codon_backward(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Parse the codon in reverse direction.

        :param variant: The variant.
        :return: Codon, codon positions, codon start position, position within codon, and relative position.
        :raises IndexError: If the codon extends beyond the current coding sequence and no
            neighbouring coding sequence is available on that side.
        """
        # position relative to end of coding sequence
        pos_rel = (self._cd.end - int(self._cd.phase)) - variant.POS

        # position relative to codon end
        pos_codon = pos_rel % 3

        # inclusive codon start, 1-based
        codon_start = variant.POS + pos_codon

        # the codon positions, listed in reading direction (descending genomic position)
        codon_pos = [codon_start, codon_start - 1, codon_start - 2]

        # Only the previous coding sequence is needed to complete a codon that reaches
        # below the start of the current one; the subsequent coding sequence is irrelevant here.
        if self._cd_prev is None and codon_pos[2] < self._cd.start:
            raise IndexError(f'Codon at site {variant.CHROM}:{variant.POS} overlaps with '
                             f'start position of current CDS and no previous CDS was given.')

        # Use final positions from previous coding sequence if current codon
        # ends before start position of current coding sequence.
        if codon_pos[1] == self._cd.start:
            codon_pos[2] = self._cd_prev.end if self._cd_prev.strand == '-' else self._cd_prev.start
        elif codon_pos[0] == self._cd.start:
            codon_pos[1] = self._cd_prev.end if self._cd_prev.strand == '-' else self._cd_prev.start
            codon_pos[2] = self._cd_prev.end - 1 if self._cd_prev.strand == '-' else self._cd_prev.start + 1

        if (self._cd_next is None or self._cd_next.start == self._pos_mock) and codon_pos[0] > self._cd.end:
            raise IndexError(f'Codon at site {variant.CHROM}:{variant.POS} overlaps with '
                             f'end position of current CDS and no subsequent CDS was given.')

        # use initial positions from subsequent coding sequence if current codon
        # starts after end position of current coding sequence
        if codon_pos[0] == self._cd.end + 1:
            codon_pos[0] = self._cd_next.start if self._cd_next.strand == '-' else self._cd_next.end
        elif codon_pos[1] == self._cd.end + 1:
            codon_pos[1] = self._cd_next.start if self._cd_next.strand == '-' else self._cd_next.end
            codon_pos[0] = self._cd_next.start + 1 if self._cd_next.strand == '-' else self._cd_next.end - 1

        # we use 0-based positions here
        codon = ''.join(str(self._contig[int(pos - 1)]) for pos in codon_pos)

        # take complement and convert to uppercase ('n' might be lowercase)
        codon = str(Seq(codon).complement()).upper()

        return codon, codon_pos, codon_start, pos_codon, pos_rel

    def _parse_codon(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Parse the codon for the given variant, dispatching on the strand of the current coding sequence.

        :param variant: The variant to parse the codon for.
        :return: Codon, codon positions, codon start position, position within codon, and relative position.
        """
        if self._cd.strand == '+':
            return self._parse_codon_forward(variant)

        return self._parse_codon_backward(variant)

    @staticmethod
    def _get_degeneracy(codon: str, pos: int) -> int:
        """
        Determine the degeneracy of a position within a codon by counting how many of
        the three alternative bases at that position leave the amino acid unchanged.

        :param codon: The codon
        :param pos: The position of the variant in the codon, 0, 1, or 2
        :return: The degeneracy of the codon, 0, 2, or 4
        """
        amino_acid = codon_table[codon]

        # convert to list of characters
        codon = list(codon)

        # amino acids obtained when substituting each of the other three bases
        alt = []
        for b in bases[bases != codon[pos]]:
            codon[pos] = b
            alt.append(codon_table[''.join(codon)])

        # noinspection PyTypeChecker
        return unique_to_degeneracy[sum(amino_acid == np.array(alt))]

    @staticmethod
    def _get_degeneracy_table() -> Dict[str, str]:
        """
        Create codon degeneracy table.

        :return: dictionary mapping codons to a three-character string holding the degeneracy of each position
        """
        codon_degeneracy = {}
        for codon in product(bases, repeat=3):
            codon = ''.join(codon)
            codon_degeneracy[codon] = ''.join(
                [str(DegeneracyAnnotation._get_degeneracy(codon, pos)) for pos in range(0, 3)]
            )

        return codon_degeneracy

    def _fetch_cds(self, v: Union['cyvcf2.Variant', DummyVariant]):
        """
        Fetch the coding sequence for the given variant.

        :param v: The variant to fetch the coding sequence for.
        :raises LookupError: If no coding sequence was found.
        """
        # get the aliases for the current chromosome
        aliases = self._handler.get_aliases(v.CHROM)

        # only fetch coding sequence if we are on a new chromosome or the
        # variant is not within the current coding sequence
        if self._cd is None or self._cd.seqid not in aliases or v.POS > self._cd.end:

            # reset coding sequences to mocking positions
            self._cd_prev = None
            self._cd = pd.Series({'seqid': v.CHROM, 'start': self._pos_mock, 'end': self._pos_mock})
            self._cd_next = pd.Series({'seqid': v.CHROM, 'start': self._pos_mock, 'end': self._pos_mock})

            # filter for the current chromosome
            on_contig = self._handler._cds[(self._handler._cds.seqid.isin(aliases))]

            # filter for positions ending after the variant
            cds = on_contig[(on_contig.end >= v.POS)]

            if not cds.empty:
                # take the first coding sequence
                self._cd = cds.iloc[0]

                self._logger.debug(f'Found coding sequence: {self._cd.seqid}:{self._cd.start}-{self._cd.end}, '
                                   f'reminder: {(self._cd.end - self._cd.start + 1) % 3}, '
                                   f'phase: {int(self._cd.phase)}, orientation: {self._cd.strand}, '
                                   f'current position: {v.CHROM}:{v.POS}')

                # filter for positions starting after the current coding sequence
                cds = on_contig[(on_contig.start > self._cd.end)]

                if not cds.empty:
                    # take the first coding sequence
                    self._cd_next = cds.iloc[0]

                # filter for positions ending before the current coding sequence
                cds = on_contig[(on_contig.end < self._cd.start)]

                if not cds.empty:
                    # take the last coding sequence
                    self._cd_prev = cds.iloc[-1]

            if self._cd.start == self._pos_mock and self.n_annotated == 0:
                self._logger.warning(f"No coding sequence found on all of contig '{v.CHROM}' and no previous "
                                     f'sites were annotated. Are you sure that this is the correct GFF file '
                                     f'and that the contig names match the chromosome names in the VCF file? '
                                     f'Note that you can also specify aliases for contig names in the VCF file.')

        # check if variant is located within coding sequence
        if self._cd is None or not (self._cd.start <= v.POS <= self._cd.end):
            raise LookupError(f"No coding sequence found, skipping record {v.CHROM}:{v.POS}")

    def _fetch_contig(self, v: Union['cyvcf2.Variant', DummyVariant]):
        """
        Fetch the contig for the given variant unless the cached contig is still the right one.

        :param v: The variant to fetch the contig for.
        """
        aliases = self._handler.get_aliases(v.CHROM)

        # check if contig is up-to-date
        if self._contig is None or self._contig.id not in aliases:
            self._logger.debug(f"Fetching contig '{v.CHROM}'.")

            # fetch contig
            self._contig = self._handler.get_contig(aliases)

    def _fetch(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Fetch all required data for the given variant.

        :param variant: The variant to fetch the data for.
        :raises LookupError: if some data could not be found.
        """
        self._fetch_cds(variant)

        try:
            self._fetch_contig(variant)
        except LookupError:
            # log error as this should not happen
            self._logger.warning(f"Could not fetch contig '{variant.CHROM}'.")
            raise

    def annotate_site(self, v: Union['cyvcf2.Variant', DummyVariant]):
        """
        Annotate a single site with its degeneracy. Sites outside coding sequences and
        sites whose reference allele is not a single base are skipped.

        :param v: The variant to annotate.
        """
        v.INFO['Degeneracy'] = '.'

        try:
            self._fetch(v)
        except LookupError:
            self.n_skipped += 1
            return

        # skip locus if not a single site
        if len(v.REF) != 1:
            self.n_skipped += 1
            return

        # annotate if record is in coding sequence
        if self._cd.seqid in self._handler.get_aliases(v.CHROM) and self._cd.start <= v.POS <= self._cd.end:

            try:
                # parse codon
                codon, codon_pos, codon_start, pos_codon, pos_rel = self._parse_codon(v)

            except IndexError as e:
                # skip site on IndexError
                self._logger.warning(e)
                self.errors.append(v)
                return

            # make sure the reference allele matches with the position on the reference genome
            if str(self._contig[v.POS - 1]).upper() != v.REF.upper():
                self._logger.warning(
                    f"Reference allele does not match with reference genome at {v.CHROM}:{v.POS}, skipping site."
                )
                self.mismatches.append(v)
                return

            degeneracy = '.'
            if 'N' not in codon:
                degeneracy = self._get_degeneracy(codon, pos_codon)

                # increment counter of annotated sites
                self.n_annotated += 1

            v.INFO['Degeneracy'] = degeneracy
            v.INFO['Degeneracy_Info'] = f"{pos_codon},{self._cd.strand},{codon}"

            self._logger.debug(f'pos codon: {pos_codon}, pos abs: {v.POS}, '
                               f'codon start: {codon_start}, codon: {codon}, '
                               f'strand: {self._cd.strand}, ref allele: {self._contig[v.POS - 1]}, '
                               f'degeneracy: {degeneracy}, codon pos: {str(codon_pos)}, '
                               f'ref allele: {v.REF}')
class SynonymyAnnotation(DegeneracyAnnotation):
    """
    Synonymy annotation. This class annotates a variant with the synonymous/non-synonymous status.

    This annotation adds the info fields ``Synonymy`` and ``Synonymy_Info``, which hold
    the synonymy (synonymous [1] or non-synonymous [0]) and the codon information, respectively.
    To be used with :class:`~fastdfe.parser.SynonymyStratification`.

    For this annotation to work, we require a FASTA and GFF file (passed to :class:`~fastdfe.parser.Parser` or
    :class:`~fastdfe.annotation.Annotator`).

    This class was tested against `VEP <VEP_>`_ and `SnpEff <SnpEff_>`_ and provides the same annotations in almost
    all cases.

    .. _VEP: https://www.ensembl.org/info/docs/tools/vep/index.html
    .. _SnpEff: https://pcingola.github.io/SnpEff/

    .. warning::
        Not recommended for use with :class:`~fastdfe.parser.Parser` as we also need to annotate mono-allelic sites.
        Consider using :class:`~fastdfe.annotation.DegeneracyAnnotation` and
        :class:`~fastdfe.parser.DegeneracyStratification` instead.
    """

    def __init__(self):
        """
        Create a new annotation instance.
        """
        super().__init__()

        #: The sites whose codons did not match with the codons reported by VEP.
        self.vep_mismatches: List['cyvcf2.Variant'] = []

        #: The sites whose synonymy did not match with the annotation provided by SnpEff.
        self.snpeff_mismatches: List['cyvcf2.Variant'] = []

        #: The number of sites that were compared with VEP.
        self.n_vep_comparisons: int = 0

        #: The number of sites carrying a SnpEff annotation.
        self.n_snpeff_comparisons: int = 0

    def _setup(self, handler: MultiHandler):
        """
        Provide context to the annotation and register the info fields in the VCF header.

        :param handler: The handler.
        """
        # require FASTA and GFF files
        handler._require_fasta(self.__class__.__name__)
        handler._require_gff(self.__class__.__name__)

        # bypass the setup of the degeneracy annotation so that
        # its info fields are not added to the header
        Annotation._setup(self, handler)

        # touch the cached properties to make for a nicer logging experience
        # noinspection PyStatementEffect
        self._handler._cds

        # noinspection PyStatementEffect
        self._handler._ref

        # ``annotate_site`` writes ``int(is_synonymous(...))``, i.e., 1 for synonymous sites
        handler._reader.add_info_to_header({
            'ID': 'Synonymy',
            'Number': '.',
            'Type': 'Integer',
            'Description': 'Synonymous (1) or non-synonymous (0)'
        })

        handler._reader.add_info_to_header({
            'ID': 'Synonymy_Info',
            'Number': '.',
            'Type': 'String',
            'Description': 'Alt codon and extra information'
        })

    def _get_alt_allele(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> str | None:
        """
        Get the alternative allele, complemented if the current coding sequence is on the reverse strand.

        :param variant: The variant to get the alternative allele for.
        :return: The alternative allele or None if there is no alternative allele.
        """
        if len(variant.ALT) > 0:

            # assume there is at most one alternative allele
            if self._cd.strand == '-':
                return str(Seq(variant.ALT[0]).complement())

            return variant.ALT[0]

        return None

    @staticmethod
    def mutate(codon: str, alt: str, pos: int) -> str:
        """
        Mutate the codon at the given position to the given alternative allele.

        :param codon: The codon to mutate.
        :param alt: The alternative allele.
        :param pos: The position to mutate.
        :return: Mutated codon.
        """
        return codon[0:pos] + alt + codon[pos + 1:]

    @staticmethod
    def is_synonymous(codon1: str, codon2: str) -> bool:
        """
        Check if two codons are synonymous.

        :param codon1: The first codon.
        :param codon2: The second codon.
        :return: True if the codons are synonymous, False otherwise.
        """
        # handle case where there are stop codons
        if codon1 in stop_codons or codon2 in stop_codons:
            return codon1 in stop_codons and codon2 in stop_codons

        return codon_table[codon1] == codon_table[codon2]

    def _parse_codons_vep(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> List[str]:
        """
        Parse the codons from the VEP annotation if present.

        :param variant: The variant.
        :return: The two codons of the first ``xxx/yyy`` codon pair in the ``CSQ`` field in uppercase,
            or an empty list if there is none.
        """
        # match codons; the pattern has exactly two groups so a match always yields two codons
        match = re.search("([actgACTG]{3})/([actgACTG]{3})", variant.INFO.get('CSQ'))

        if match is not None:
            return [m.upper() for m in [match[1], match[2]]]

        return []

    @staticmethod
    def _parse_synonymy_snpeff(variant: Union['cyvcf2.Variant', DummyVariant]) -> int | None:
        """
        Parse the synonymy from the annotation provided by SnpEff.

        :param variant: The variant.
        :return: 1 if the ``ANN`` field reports a synonymous variant, 0 if it reports a missense
            variant, and None otherwise.
        """
        ann = variant.INFO.get('ANN')

        if 'synonymous_variant' in ann:
            return 1

        if 'missense_variant' in ann:
            return 0

        return None

    def _teardown(self):
        """
        Finalize the annotation, reporting the mismatches with VEP and SnpEff if any comparisons were made.
        """
        super()._teardown()

        if self.n_vep_comparisons != 0:
            self._logger.info(f'Number of mismatches with VEP: {len(self.vep_mismatches)}')

        if self.n_snpeff_comparisons != 0:
            self._logger.info(f'Number of mismatches with SnpEff: {len(self.snpeff_mismatches)}')

    def annotate_site(self, v: Union['cyvcf2.Variant', DummyVariant]):
        """
        Annotate a single site with its synonymy. Only SNPs within coding sequences are annotated.

        :param v: The variant to annotate.
        """
        v.INFO['Synonymy'] = '.'

        # only SNPs are annotated
        if not v.is_snp:
            return

        try:
            self._fetch(v)
        except LookupError:
            self.n_skipped += 1
            return

        # annotate only if record is in coding sequence
        if not (self._cd.start <= v.POS <= self._cd.end):
            return

        try:
            # parse codon
            codon, codon_pos, codon_start, pos_codon, pos_rel = self._parse_codon(v)

        except IndexError as e:
            # skip site on IndexError
            self._logger.warning(e)
            self.errors.append(v)
            return

        # make sure the reference allele matches with the position in the reference genome
        if str(self._contig[v.POS - 1]).upper() != v.REF.upper():
            self._logger.warning(
                f"Reference allele does not match with reference genome at {v.CHROM}:{v.POS}, skipping site."
            )
            self.mismatches.append(v)
            return

        # fetch the alternative allele if present
        alt = self._get_alt_allele(v)

        info = ''
        synonymy = '.'

        if alt is not None:
            # alternative codon, 'n' might not be uppercase
            alt_codon = self.mutate(codon, alt, pos_codon).upper()

            # whether the alternative codon is synonymous
            if 'N' not in codon and 'N' not in alt_codon:
                synonymy = int(self.is_synonymous(codon, alt_codon))

            # append alternative codon to info field
            info += f'{codon}/{alt_codon}'

            # check if the alternative codon is a start codon
            if alt_codon in start_codons:
                info += ',start_gained'

            # check if the alternative codon is a stop codon
            if alt_codon in stop_codons:
                info += ',stop_gained'

            if v.INFO.get('CSQ') is not None:

                # fetch the codons from the VEP annotation
                codons_vep = self._parse_codons_vep(v)

                if len(codons_vep) > 0:
                    # increase number of comparisons
                    self.n_vep_comparisons += 1

                    # make sure the codons determined by VEP are the same as our codons.
                    if set(codons_vep) != {codon, alt_codon}:
                        self._logger.warning(f'VEP codons do not match with codons determined by '
                                             f'codon table for {v.CHROM}:{v.POS}')

                        self.vep_mismatches.append(v)
                        return

            if v.INFO.get('ANN') is not None:
                synonymy_snpeff = self._parse_synonymy_snpeff(v)

                # NOTE(review): counted for every site carrying ``ANN``, even when SnpEff
                # provides neither a synonymous nor a missense call to compare against
                self.n_snpeff_comparisons += 1

                if synonymy_snpeff is not None:
                    if synonymy_snpeff != synonymy:
                        self._logger.warning(f'SnpEff annotation does not match with custom '
                                             f'annotation for {v.CHROM}:{v.POS}')

                        self.snpeff_mismatches.append(v)
                        return

        # increase number of annotated sites
        self.n_annotated += 1

        # add to info field
        v.INFO['Synonymy'] = synonymy
        v.INFO['Synonymy_Info'] = info
class AncestralAlleleAnnotation(Annotation, ABC):
    """
    Abstract base class of annotations that assign ancestral alleles.
    """

    def _setup(self, handler: MultiHandler):
        """
        Store the handler and register the ancestral allele info field in the VCF header.

        :param handler: The handler.
        """
        super()._setup(handler)

        register = handler._reader.add_info_to_header

        # the name of the info field is taken from the handler
        header_entry = {
            'ID': self._handler.info_ancestral,
            'Number': '.',
            'Type': 'Character',
            'Description': 'Ancestral Allele'
        }

        register(header_entry)
class MaximumParsimonyAncestralAnnotation(AncestralAlleleAnnotation):
    """
    Annotation of ancestral alleles using maximum parsimony. To be used with
    :class:`~fastdfe.parser.AncestralBaseStratification` and :class:`Annotator` or :class:`~fastdfe.parser.Parser`.

    Note that maximum parsimony is not a reliable way to determine ancestral alleles, so it is recommended
    to use this annotation together with the ancestral misidentification parameter ``eps`` or to fold
    spectra altogether. Alternatively, you can use :class:`~fastdfe.filtration.DeviantOutgroupFiltration` to
    filter out sites where the major allele among outgroups does not coincide with the major allele among ingroups.

    This annotation has the advantage of requiring no outgroup data. If outgroup data is available consider using
    :class:`MaximumLikelihoodAncestralAnnotation` instead.
    """

    def __init__(self, samples: List[str] = None):
        """
        Create a new ancestral allele annotation instance.

        :param samples: The samples to consider when determining the ancestral allele. If ``None``, all samples are
            considered.
        """
        super().__init__()

        #: The samples to consider when determining the ancestral allele.
        self.samples: List[str] | None = samples

        #: Boolean mask over the samples of the VCF reader, created in :meth:`_setup`.
        self.samples_mask: np.ndarray | None = None

    def _setup(self, handler: MultiHandler):
        """
        Register the info field in the header and create the mask of the samples to consider.

        :param handler: The handler.
        """
        super()._setup(handler)

        if self.samples is not None:
            # select the requested samples only
            self.samples_mask = np.isin(handler._reader.samples, self.samples)
        else:
            # select all samples
            self.samples_mask = np.ones(len(handler._reader.samples), dtype=bool)

    def annotate_site(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Annotate a single site with its ancestral allele.

        :param variant: The variant to annotate.
        """
        # assign the ancestral allele
        variant.INFO[self._handler.info_ancestral] = self._get_ancestral(variant, self.samples_mask)

        # count the site only if an ancestral allele could be determined
        undetermined = variant.INFO[self._handler.info_ancestral] == '.'

        if not undetermined:
            self.n_annotated += 1

    @staticmethod
    def _get_ancestral(
            variant: Union['cyvcf2.Variant', DummyVariant],
            mask: np.ndarray,
    ) -> str:
        """
        Get the ancestral allele for the given variant using maximum parsimony.

        :param variant: The variant to annotate.
        :param mask: The mask for the ingroups.
        :return: The ancestral allele or '.' if it could not be determined.
        """
        # dummy variants: the reference allele is taken to be ancestral
        if isinstance(variant, DummyVariant):
            return variant.REF

        # SNPs: the major base among the selected samples is taken to be ancestral
        if variant.is_snp:
            return get_major_base(variant.gt_bases[mask]) or '.'

        # monomorphic sites: the only called base is taken to be ancestral
        if variant.REF in bases:
            called = set(get_called_bases(variant.gt_bases[mask]))

            if len(called) == 1:
                base = next(iter(called))

                if base in bases:
                    return base

        return '.'
class SubstitutionModel(ABC):
    """
    Base class for substitution models.

    Subclasses implement :meth:`_get_prob`, which provides the probability of observing
    two nucleotide states at the ends of a branch.
    """

    #: The possible transitions as pairs of base indices
    _transitions: np.ndarray = np.array([
        (base_indices['A'], base_indices['G']),
        (base_indices['G'], base_indices['A']),
        (base_indices['C'], base_indices['T']),
        (base_indices['T'], base_indices['C'])
    ])

    #: The possible transversions as pairs of base indices
    _transversions: np.ndarray = np.array([
        (base_indices['A'], base_indices['C']),
        (base_indices['C'], base_indices['A']),
        (base_indices['G'], base_indices['T']),
        (base_indices['T'], base_indices['G']),
        (base_indices['A'], base_indices['T']),
        (base_indices['T'], base_indices['A']),
        (base_indices['C'], base_indices['G']),
        (base_indices['G'], base_indices['C'])
    ])

    def __init__(
            self,
            bounds: Dict[str, Tuple[float, float]] = {},
            pool_branch_rates: bool = False,
            fixed_params: Dict[str, float] = {}
    ):
        """
        Create a new substitution model instance.

        :param bounds: The bounds for the parameters.
        :param pool_branch_rates: Whether to pool the branch rates. By default, each branch has its own rate which
            is optimized using MLE. If ``True``, the branch rates are pooled and a single rate is optimized. This is
            useful if the number of sites used is small.
        :param fixed_params: The fixed parameters. Parameters that are not fixed are optimized using MLE.
        :raises ValueError: If the bounds are invalid.
        """
        #: The logger, derived from the package logger so that it shares its configuration.
        self._logger = logger.getChild(self.__class__.__name__)

        # validate bounds
        self.validate_bounds(bounds)

        #: Whether to pool the branch rates.
        self.pool_branch_rates: bool = pool_branch_rates

        # the dictionaries are copied so the (shared) default arguments are never modified

        #: The fixed parameters.
        self.fixed_params: Dict[str, float] = fixed_params.copy()

        #: Parameter bounds.
        self.bounds: Dict[str, Tuple[float, float]] = bounds.copy()

        #: Cache for the probabilities, keyed by (first base, second base, branch index).
        self._cache: Dict[Tuple[int, int, int], float] | None = None

    def _setup(self, ann: 'MaximumLikelihoodAncestralAnnotation'):
        """
        Set up the substitution model. Does nothing by default.

        :param ann: The ancestral allele annotation.
        """
        pass

    def cache(self, params: Dict[str, float], n_branches: int):
        """
        Cache the probabilities of all pairs of bases on all branches for the given parameters.

        :param params: The parameters.
        :param n_branches: The number of branches.
        """
        self._cache = {
            (b1, b2, i): self._get_prob(b1, b2, i, params)
            for (b1, b2, i) in itertools.product(range(0, 4), range(0, 4), range(n_branches))
        }

    @staticmethod
    def get_x0(bounds: Dict[str, Tuple[float, float]], rng: np.random.Generator) -> Dict[str, float]:
        """
        Get the initial values for the parameters.

        :param bounds: The bounds for the parameters.
        :param rng: The random number generator.
        :return: The initial values.
        """
        # draw initial values from a uniform distribution within the bounds
        return {key: rng.uniform(lower, upper) for key, (lower, upper) in bounds.items()}

    def get_bounds(self, n_outgroups: int) -> Dict[str, Tuple[float, float]]:
        """
        Get the bounds for the parameters.

        :param n_outgroups: The number of outgroups (not used by this base implementation).
        :return: The bounds.
        """
        return self.bounds

    @staticmethod
    def validate_bounds(bounds: Dict[str, Tuple[float, float]]):
        """
        Make sure the lower bounds are positive and not larger than the upper bounds.

        :param bounds: The bounds to validate
        :raises ValueError: If the bounds are invalid
        """
        for param, (lower, upper) in bounds.items():
            if lower <= 0:
                raise ValueError(f'All lower bounds must be positive, got {lower} for {param}.')

            if lower > upper:
                raise ValueError(f'Lower bounds must be smaller than upper bounds, got {lower} > {upper} for {param}.')

    @abstractmethod
    def _get_prob(self, b1: int, b2: int, i: int, params: Dict[str, float]) -> float:
        """
        Get the probability of a branch using the substitution model.

        :param b1: First nucleotide state.
        :param b2: Second nucleotide state.
        :param i: The index of the branch.
        :param params: The parameters for the model.
        :return: The probability of the branch.
        """
        pass

    def _get_cached_prob(self, b1: int, b2: int, i: int, params: Dict[str, float]) -> float:
        """
        Get the probability of a branch using the substitution model with caching.
        The probability is computed directly if :meth:`cache` has not been called yet.

        :param b1: First nucleotide state.
        :param b2: Second nucleotide state.
        :param i: The index of the branch.
        :param params: The parameters for the model, only used if nothing was cached.
        :return: The probability of the branch.
        """
        if self._cache is None:
            return self._get_prob(b1, b2, i, params)

        # return cached value
        return self._cache[(b1, b2, i)]
class JCSubstitutionModel(SubstitutionModel):
    """
    Jukes-Cantor substitution model.
    """

    def __init__(
            self,
            bounds: Dict[str, Tuple[float, float]] = {'K': (1e-5, 10)},
            pool_branch_rates: bool = False,
            fixed_params: Dict[str, float] = {}
    ):
        """
        Create a new substitution model instance.

        :param bounds: The bounds for the parameters. K is the branch rate.
        :param pool_branch_rates: Whether to pool the branch rates. By default, each branch has its own rate which
            is optimized using MLE. If ``True``, the branch rates are pooled and a single rate is optimized. This is
            useful if the number of sites used is small. If ``False``, each branch has its own rate denoted by "K{i}",
            where i is the branch index. If ``True``, the branch rate is denoted by "K".
        :param fixed_params: The fixed parameters. Parameters that are not fixed are optimized using MLE.
        """
        super().__init__(bounds=bounds, pool_branch_rates=pool_branch_rates, fixed_params=fixed_params)

    def get_bound(self, param: str) -> Tuple[float, float]:
        """
        Get the bounds for a parameter.

        :param param: The parameter.
        :return: The lower and upper bounds.
        """
        # a fixed parameter is pinned by using its value as both lower and upper bound
        if param in self.fixed_params:
            value = self.fixed_params[param]
            return value, value

        # explicitly defined bounds take precedence
        if param in self.bounds:
            return self.bounds[param]

        # otherwise strip the digits (the branch index) and look up the remaining name
        stripped = re.sub(r'\d', '', param)

        return self.bounds[stripped]

    def get_bounds(self, n_outgroups: int) -> Dict[str, Tuple[float, float]]:
        """
        Get the bounds for the parameters.

        :param n_outgroups: The number of outgroups.
        :return: The lower and upper bounds.
        """
        # a single rate shared by all branches
        if self.pool_branch_rates:
            return {'K': self.get_bound('K')}

        # one rate for each of the ``2 * n_outgroups - 1`` branches
        names = [f"K{i}" for i in range(2 * n_outgroups - 1)]

        return {name: self.get_bound(name) for name in names}

    def _get_prob(self, b1: int, b2: int, i: int, params: Dict[str, float]) -> float:
        """
        Get the probability of a branch using the substitution model.

        :param b1: First nucleotide state.
        :param b2: Second nucleotide state.
        :param i: The index of the branch.
        :param params: The parameters for the model.
        :return: The probability of the branch.
        """
        # evolutionary rate parameter for the branch
        if self.pool_branch_rates:
            K = params['K']
        else:
            K = params[f'K{i}']

        decay = np.exp(-K)

        # bases differ
        if b1 != b2:
            return (1 / 3) * K * decay + (1 / 9) * K ** 2 * decay

        # bases are identical
        return decay + (1 / 6) * K ** 2 * decay
class K2SubstitutionModel(JCSubstitutionModel):
    """
    Kimura 2-parameter substitution model.
    """

    def __init__(
            self,
            bounds: Dict[str, Tuple[float, float]] = {'K': (1e-5, 10), 'k': (0.1, 10)},
            pool_branch_rates: bool = False,
            fixed_params: Dict[str, float] = {},
            fix_transition_transversion_ratio: bool = False
    ):
        """
        Create a new substitution model instance.

        :param bounds: The bounds for the parameters. ``K{i}`` are the branch rates.
            ``k`` is the transition/transversion ratio.
        :param pool_branch_rates: Whether to pool the branch rates. By default, each branch has its own rate which
            is optimized using MLE. If ``True``, the branch rates are pooled and a single rate is optimized. This is
            useful if the number of sites used is small.
        :param fixed_params: The fixed parameters. Parameters that are not fixed are optimized using MLE.
        :param fix_transition_transversion_ratio: Whether to fix the transition/transversion ratio to the ratio
            observed in the data.
        """
        # The default arguments are dict literals shared by every call. Pass private copies so that
        # nothing written to this model's dictionaries later on can reach other instances.
        super().__init__(
            bounds=dict(bounds),
            pool_branch_rates=pool_branch_rates,
            fixed_params=dict(fixed_params)
        )

        #: Whether to fix the transition/transversion ratio to the ratio observed in the data.
        self.fix_transition_transversion_ratio: bool = fix_transition_transversion_ratio

    def _setup(self, ann: 'MaximumLikelihoodAncestralAnnotation'):
        """
        Set up the substitution model.

        :param ann: The ancestral allele annotation.
        """
        if self.fix_transition_transversion_ratio:
            # Fix the transition/transversion ratio to the ratio observed in the data.
            # A new dict is bound instead of writing into the existing one, so a dictionary
            # that is shared with the caller or with other models is never modified.
            self.fixed_params = {
                **self.fixed_params,
                'k': ann.get_observed_transition_transversion_ratio()
            }

    def get_bounds(self, n_outgroups: int) -> Dict[str, Tuple[float, float]]:
        """
        Get the bounds for the parameters.

        :param n_outgroups: The number of outgroups.
        :return: The lower and upper bounds.
        """
        bounds = super().get_bounds(n_outgroups)

        # add bounds for the transition/transversion ratio
        bounds["k"] = self.get_bound("k")

        return bounds

    def _get_prob(self, b1: int, b2: int, i: int, params: Dict[str, float]) -> float:
        """
        Get the probability of a branch using the K2 model.

        :param b1: First nucleotide state.
        :param b2: Second nucleotide state.
        :param i: The index of the branch.
        :param params: The parameters for the model.
        :return: The probability of the branch.
        """
        # evolutionary rate parameter for the branch
        K = params['K'] if self.pool_branch_rates else params[f'K{i}']

        # transition/transversion ratio
        k = params["k"]

        # if the ancestral and descendant nucleotide states are the same
        if b1 == b2:
            return np.exp(-K) * (1 + 0.5 * K ** 2 * (2 + k ** 2) / (k ** 2 + 4 * k + 4))

        # if we have a transition (state pairs 0 <-> 2 and 1 <-> 3)
        if (b1, b2) in ((0, 2), (2, 0), (1, 3), (3, 1)):
            return K * np.exp(-K) * (k / (k + 2) + K * 1 / (k ** 2 + 4 * k + 4))

        # if we have a transversion
        return K * np.exp(-K) * (1 / (k + 2) + K * k / (k ** 2 + 4 * k + 4))
class SiteConfig:
    """
    Ancestral allele site configuration for a single subsample.
    """

    def __init__(
            self,
            n_major: int,
            major_base: int,
            minor_base: int,
            outgroup_bases: np.ndarray,
            multiplicity: float = 1.0,
            sites: np.ndarray = None,
            p_minor: float = np.nan,
            p_major: float = np.nan
    ):
        """
        Create a new site configuration instance.

        :param n_major: The number of major alleles.
        :param major_base: The major allele base index.
        :param minor_base: The minor base index.
        :param outgroup_bases: The outgroup base indices.
        :param multiplicity: The multiplicity of the site.
        :param sites: The site indices. An empty array is used if ``None``.
        :param p_minor: The probability of the minor allele.
        :param p_major: The probability of the major allele.
        """
        #: The number of major alleles.
        self.n_major: int = n_major

        #: The major allele base index.
        self.major_base: int = major_base

        #: The minor base index.
        self.minor_base: int = minor_base

        #: The outgroup base indices.
        self.outgroup_bases: np.ndarray = outgroup_bases

        #: The multiplicity of the site.
        self.multiplicity: float = multiplicity

        #: The site indices.
        self.sites: np.ndarray = sites if sites is not None else np.array([])

        #: The probability of the minor allele.
        self.p_minor: float = p_minor

        #: The probability of the major allele.
        self.p_major: float = p_major
class SiteInfo:
    """
    Ancestral allele information on a single site.
    """

    def __init__(
            self,
            n_major: Dict[int, float],
            major_base: str,
            minor_base: str,
            outgroup_bases: List[str],
            p_minor: float = np.nan,
            p_major: float = np.nan,
            p_major_ancestral: float = np.nan,
            major_ancestral: str = '.',
            p_bases_first_node: Dict[str, float] = None,
            p_first_node_ancestral: float = np.nan,
            first_node_ancestral: str = '.',
            rate_params: Dict[str, float] = None
    ):
        """
        Create a new site information instance. Dictionary arguments left at ``None``
        are replaced by empty dictionaries.
        """
        #: Dictionary mapping number of major alleles to its probability of observation.
        self.n_major: Dict[int, float] = n_major

        #: The major allele base.
        self.major_base: str = major_base

        #: The minor allele base.
        self.minor_base: str = minor_base

        #: The outgroup bases.
        self.outgroup_bases: List[str] = outgroup_bases

        #: The probability of the minor allele being the ancestral allele (without prior).
        self.p_minor: float = p_minor

        #: The probability of the major allele being the ancestral allele (without prior).
        self.p_major: float = p_major

        #: The probability of the major allele being the ancestral allele rather than the minor allele
        #: (possibly with prior if specified).
        self.p_major_ancestral: float = p_major_ancestral

        #: The predicted ancestral base based on comparing major and minor allele.
        self.major_ancestral: str = major_ancestral

        #: The probability of each base being the ancestral base for the first internal node.
        self.p_bases_first_node: Dict[str, float] = p_bases_first_node if p_bases_first_node is not None else {}

        #: The probability that the most likely base for the first internal node is the ancestral base.
        self.p_first_node_ancestral: float = p_first_node_ancestral

        #: The ancestral base for the first internal node.
        self.first_node_ancestral: str = first_node_ancestral

        #: The branch rates.
        self.rate_params: Dict[str, float] = rate_params if rate_params is not None else {}

    def plot_tree(
            self,
            ax: 'plt.Axes' = None,
            show: bool = True,
    ):
        """
        Plot the tree for a site. Only Python visualization is supported.

        :param ax: Axes to plot on. The current axes are used if ``None``.
        :param show: Whether to show the plot.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        # a pooled rate ('K') is expanded into one identical entry per branch
        if 'K' in self.rate_params:
            pooled_rate = self.rate_params['K']
            branch_lengths = {f'K{j}': pooled_rate for j in range(len(self.outgroup_bases) * 2 - 1)}
        else:
            branch_lengths = self.rate_params

        n_outgroups = len(self.outgroup_bases)

        # the ingroup clade holds the major and the minor allele as zero-length leaves
        subtree = Clade(
            name="ingroup",
            clades=[
                Clade(name=self.major_base, branch_length=0),
                Clade(name=self.minor_base, branch_length=0)
            ],
            branch_length=branch_lengths['K0'] if n_outgroups > 0 else 0
        )

        # attach the outgroups one by one, each time wrapping the tree built so far
        for i, outgroup_base in enumerate(self.outgroup_bases):
            # number of outgroups still to be attached after this one
            remaining = n_outgroups - 1 - i

            # the last outgroup has half the branch length to the root
            # as we have no internal node
            if remaining > 0:
                leaf_length = branch_lengths[f"K{2 * i + 1}"]
            else:
                leaf_length = branch_lengths[f"K{2 * i}"] / 2

            leaf = Clade(name=outgroup_base, branch_length=leaf_length)

            # branch length from the new node to the next one:
            # halved just below the root, and zero for the root itself
            if remaining > 1:
                parent_length = branch_lengths[f"K{2 * i + 2}"]
            elif remaining == 1:
                parent_length = branch_lengths[f"K{2 * i + 2}"] / 2
            else:
                parent_length = 0

            # internal node, or the unnamed root for the last outgroup
            subtree = Clade(
                name=f"internal {i + 1}" if remaining > 0 else None,
                clades=[leaf, subtree],
                branch_length=parent_length
            )

        # create a tree object and visualize
        Phylo.draw(Tree(root=subtree), axes=ax, do_show=False)

        # remove Y-axis
        ax.axes.get_yaxis().set_visible(False)

        # remove frame
        for side in ['top', 'right', 'left']:
            ax.spines[side].set_visible(False)

        if show:
            plt.show()
class _TooFewIngroupsSiteError(ValueError):
    """
    Signals that a site has too few ingroups to be considered for ancestral allele annotation.
    """
class _PolyAllelicSiteError(ValueError):
    """
    Signals that a site has more than two alleles.
    """
class BaseType(Enum):
    """
    Distinguishes the two allele roles at a site: minor (0) and major (1).
    """
    MINOR = 0
    MAJOR = 1
class PolarizationPrior(ABC):
    """
    Base class for priors used with :class:`MaximumLikelihoodAncestralAnnotation`.

    These priors incorporate information about the general probability of the major allele being ancestral
    across all sites with the same minor allele count. Priors thus take ingroup allele frequencies into account
    when making predictions about the ancestral state of a site. This is useful because it enhances ancestral
    allele probability estimates, especially when outgroup information is unavailable for a particular site.
    Knowing the likelihood of the major allele being ancestral in general allows for more informed estimations.
    """

    def __init__(self, allow_divergence: bool = False):
        """
        Create a new instance.

        :param allow_divergence: Whether to allow divergence. If ``True``, the probability of the minor allele
            being ancestral, which is not contained in the ingroup subsample but rather in all specified ingroups
            or among the outgroup, is taken to be the same as if it was present in the ingroup subsample with
            frequency 1. This is a hack, but allows us to consider alleles that are not present in the ingroup
            subsample.

            .. warning:: Setting this to ``True`` greatly increases the probability of high-frequency derived
                alleles which introduces a strong bias in the distribution of frequency counts, e.g., the SFS.
                Only use this if you're interested in handling divergence counts, i.e., sites where the ingroup
                is mono-allelic.
        """
        #: The logger.
        self._logger = logger.getChild(type(self).__name__)

        #: Whether to allow divergence.
        self.allow_divergence: bool = allow_divergence

        #: The polarization probabilities.
        self.probabilities: np.ndarray | None = None

    def _add_divergence(self):
        """
        Overwrite the first and the last entry of the polarization probabilities, in place.

        With ``allow_divergence`` the two entries take the value of their respective neighbour,
        otherwise they are set to 1 (first entry) and 0 (last entry).
        """
        probs = self.probabilities

        if not self.allow_divergence:
            # no divergence: the boundary entries are pinned to 1 and 0
            probs[0] = 1
            probs[-1] = 0
            return

        # take divergence probabilities to be the same as alleles of frequency 1
        probs[0] = probs[1]
        probs[-1] = probs[-2]

    @abstractmethod
    def _get_prior(self, configs: pd.DataFrame, n_ingroups: int) -> np.ndarray:
        """
        Get the polarization probabilities. To be implemented by subclasses.

        :param configs: The site configurations.
        :param n_ingroups: The number of ingroups.
        """
        ...

    def plot(
            self,
            file: str = None,
            show: bool = True,
            title: str = 'polarization probabilities',
            scale: Literal['lin', 'log'] = 'lin',
            ax: 'plt.Axes' = None,
            ylabel: str = 'p'
    ) -> 'plt.Axes':
        """
        Visualize the polarization probabilities using a scatter plot.

        :param file: File to save plot to.
        :param show: Whether to show plot.
        :param title: Plot title.
        :param scale: y-scale of the plot.
        :param ax: Axes to plot on. Only for Python visualization backend.
        :param ylabel: y-axis label.
        :return: Axes object
        :raises ValueError: If the polarization probabilities have not been calculated yet.
        """
        from .visualization import Visualization

        if self.probabilities is None:
            raise ValueError('Polarization probabilities have not been calculated yet.')

        plot_options = dict(
            values=self.probabilities,
            file=file,
            show=show,
            title=title,
            scale=scale,
            ax=ax,
            ylabel=ylabel
        )

        return Visualization.plot_scatter(**plot_options)
class KingmanPolarizationPrior(PolarizationPrior):
    """
    Prior based on the standard Kingman coalescent. To be used with
    :class:`MaximumLikelihoodAncestralAnnotation`.
    """

    def _get_prior(self, configs: pd.DataFrame, n_ingroups: int) -> np.ndarray:
        """
        Get the polarization probabilities.

        :param configs: The site configurations (not used by this prior).
        :param n_ingroups: The number of ingroups.
        :return: The polarization probabilities, one entry per count from 0 to ``n_ingroups``.
        """
        probs = np.zeros(n_ingroups + 1)

        # interior counts 1, ..., n_ingroups - 1, all evaluated at once
        counts = np.arange(1, n_ingroups)
        probs[1:n_ingroups] = 1 / counts / (1 / counts + 1 / (n_ingroups - counts))

        self.probabilities = probs

        # the two boundary entries are filled in here
        self._add_divergence()

        return self.probabilities
class AdaptivePolarizationPrior(PolarizationPrior):
    """
    Adaptive prior. To be used with :class:`MaximumLikelihoodAncestralAnnotation`. This is the
    same prior as used in the EST-SFS paper. This prior is adaptive in the sense that the most likely polarization
    probabilities given the site configurations are found. This is the most accurate prior, but requires a lot of
    sites in order to work properly. You can check that the polarization probabilities are smooth enough across
    frequency counts by calling :meth:`~fastdfe.annotation.PolarizationPrior.plot`. If they are not
    smooth enough, you can increase the number of sites, decrease the number of ingroups, or
    use :class:`~fastdfe.annotation.KingmanPolarizationPrior` instead.

    .. note:: In practice, this prior provides very similar results to :class:`KingmanPolarizationPrior` in most
        cases.
    """

    def __init__(
            self,
            n_runs: int = 1,
            parallelize: bool = True,
            allow_divergence: bool = False,
            seed: int | None = 0
    ):
        """
        Create a new adaptive prior instance.

        :param n_runs: The number of runs to perform when determining the polarization parameters. One
            run should be sufficient as only one parameter is optimized.
        :param parallelize: Whether to parallelize the optimization.
        :param allow_divergence: Whether to allow divergence. See :class:`PolarizationPrior` for details.
        :param seed: The seed for the random number generator.
        """
        super().__init__(allow_divergence=allow_divergence)

        #: The number of runs to use for the adaptive prior.
        self.n_runs: int = n_runs

        #: Whether to parallelize the optimization.
        self.parallelize: bool = parallelize

        #: The seed for the random number generator.
        self.seed: int | None = seed

        #: The random number generator.
        self.rng: np.random.Generator = np.random.default_rng(seed=self.seed)

    def _get_prior(
            self,
            configs: pd.DataFrame,
            n_ingroups: int
    ) -> np.ndarray:
        """
        Get the polarization probabilities.

        :param configs: The site configurations.
        :param n_ingroups: The number of ingroups.
        :return: The polarization probabilities.
        :raises RuntimeError: If the best optimization run of any frequency bin did not succeed.
        """
        # folded frequency bin indices
        # if the number of polymorphic bins is odd, the middle bin is fixed
        freq_indices = range(1, (n_ingroups + 1) // 2)

        if len(freq_indices) > 0:
            probs = self._optimize_bins(configs, n_ingroups, freq_indices)
        else:
            # too few ingroups for any free bin: nothing to optimize,
            # only the fixed entries below remain
            probs = np.array([], dtype=float)

        # check for zeros or ones
        is_degenerate = (probs == 0) | (probs == 1)

        if np.any(is_degenerate):
            # get the number of bad frequency bins
            n_bad = np.sum(is_degenerate)

            self._logger.critical(f"Polarization probabilities are 0 or 1 for {n_bad} frequency bins which "
                                  f"can be a real problem as it means that there are no sites "
                                  f"for those bins. This may be due to ``n_ingroups`` "
                                  f"being too large, or the number of provided sites being very "
                                  f"small. If you can't increase the number of sites or decrease "
                                  f"``n_ingroups``, consider using the Kingman prior instead.")

        # unfold: the second half mirrors the first one as 1 - p
        if n_ingroups % 2 == 0:
            # if the number of ingroups is even, the middle bin is fixed to 0.5
            # noinspection all
            self.probabilities = np.concatenate(([1], probs, [0.5], 1 - probs[::-1], [0]))
        else:
            # if the number of ingroups is odd
            # noinspection all
            self.probabilities = np.concatenate(([1], probs, 1 - probs[::-1], [0]))

        # add divergence probabilities
        self._add_divergence()

        return self.probabilities

    def _optimize_bins(
            self,
            configs: pd.DataFrame,
            n_ingroups: int,
            freq_indices: range
    ) -> np.ndarray:
        """
        Optimize the polarization probability of every folded frequency bin.

        :param configs: The site configurations.
        :param n_ingroups: The number of ingroups.
        :param freq_indices: The (non-empty) folded frequency bin indices.
        :return: The optimized probability of each bin, taken from the run with the lowest
            negative log likelihood.
        :raises RuntimeError: If the best run of any bin did not succeed.
        """
        from scipy.optimize import minimize, OptimizeResult

        # get the likelihood functions
        funcs = {i: self._get_likelihood(i, configs, n_ingroups) for i in freq_indices}

        def optimize_polarization(args: List[Any]) -> OptimizeResult:
            """
            Optimize the likelihood function for a single run.

            :param args: The frequency bin index, the run index and the initial value.
            :return: The optimization results.
            """
            # unpack arguments
            i, _, x0 = args

            # optimize using scipy
            return minimize(
                fun=funcs[int(i)],
                x0=np.array([x0]),
                bounds=[(0, 1)],
                method="L-BFGS-B"
            )

        # prepare arguments, one row per (frequency bin, run) pair
        data = np.array(list(itertools.product(freq_indices, range(self.n_runs))))

        # get initial values
        initial_values = np.array([self.rng.uniform() for _ in range(data.shape[0])])

        # add initial values
        data = np.hstack((data, initial_values.reshape((-1, 1))))

        # run the optimization in parallel for each frequency bin over n_runs
        results: np.ndarray = parallelize_func(
            func=optimize_polarization,
            data=data,
            parallelize=self.parallelize,
            pbar=True,
            desc=f"{self.__class__.__name__}>Optimizing polarization priors",
            dtype=object
        ).reshape(len(freq_indices), self.n_runs)

        # get the likelihoods for each run and frequency bin
        likelihoods = np.vectorize(lambda r: r.fun)(results)

        # choose run with the best likelihood for each frequency bin
        i_best = likelihoods.argmin(axis=1)

        # Pair every bin with its own best run. Indexing with ``results[:, i_best]`` would
        # instead build the cross product of all bins with all best-run columns.
        best = results[np.arange(len(freq_indices)), i_best]

        # check for successful optimization
        failures = [str(r.message) for r in best if not r.success]

        if failures:
            raise RuntimeError("Polarization probability optimizations failed with messages: " + ", ".join(failures))

        # get the probabilities for each frequency bin
        return np.array([r.x[0] for r in best])

    @staticmethod
    def _get_likelihood(
            i: int,
            configs: pd.DataFrame,
            n_ingroups: int
    ) -> Callable[[List[float]], float]:
        """
        Get the likelihood function.

        :param i: The ith frequency bin.
        :param configs: The site configurations.
        :param n_ingroups: The number of ingroups.
        :return: The likelihood function evaluated for the ith frequency bin.
        """
        # mask for sites that have i minor alleles; it does not depend on the
        # optimized parameter, so it is computed once rather than on every evaluation
        i_minor = n_ingroups - configs.n_major == i

        p_major = configs.p_major[i_minor]
        p_minor = configs.p_minor[i_minor]
        multiplicity = configs.multiplicity[i_minor]

        def compute_likelihood(params: List[float]) -> float:
            """
            Compute the negative log likelihood of the parameters.

            :param params: The probability of polarization.
            :return: The negative log likelihood.
            """
            # get the probability of polarization for the ith frequency bin
            pi = params[0]

            # weight the sites by the probability of polarization
            p_configs = pi * p_major + (1 - pi) * p_minor

            # return the negative log likelihood
            return -(np.log(p_configs) * multiplicity).sum()

        return compute_likelihood
class _OutgroupAncestralAlleleAnnotation(AncestralAlleleAnnotation, ABC):
    """
    Abstract class for annotation of ancestral alleles using outgroup information.
    """

    def __init__(
            self,
            outgroups: List[str],
            n_ingroups: int,
            ingroups: List[str] | None = None,
            exclude: List[str] = [],
            seed: int | None = 0,
            subsample_mode: Literal['random', 'probabilistic'] = 'random'
    ):
        """
        Create a new ancestral allele annotation instance.

        :param outgroups: The outgroup samples to consider when determining the ancestral allele. A list of
            sample names as they appear in the VCF file.
        :param n_ingroups: The minimum number of ingroups that must be present at a site for it to be considered
            for ancestral allele inference.
        :param ingroups: The ingroup samples to consider when determining the ancestral allele. A list of
            sample names as they appear in the VCF file. If ``None``, all samples except the outgroups are
            considered.
        :param exclude: Samples to exclude from the ingroup. A list of sample names as they appear in the VCF file.
            ``None`` is treated like an empty list.
        :param seed: The seed for the random number generator.
        :param subsample_mode: The subsampling mode. Either 'random' or 'probabilistic'.
        :raises ValueError: If ``n_ingroups`` is smaller than 2 or ``subsample_mode`` is invalid.
        """
        # make sure the number of ingroups is at least 2
        if n_ingroups < 2:
            raise ValueError("The number of ingroups must be at least 2.")

        # check subsample mode
        if subsample_mode not in ['random', 'probabilistic']:
            raise ValueError(f"Invalid subsample mode: {subsample_mode}")

        super().__init__()

        #: The ingroup samples to consider when determining the ancestral allele.
        self.ingroups: List[str] | None = ingroups

        #: The samples excluded from the ingroup. A private copy is stored so that the
        #: default list, which is shared by all calls, is never handed out.
        self.exclude: List[str] = [] if exclude is None else list(exclude)

        #: The outgroup samples to consider when determining the ancestral allele.
        self.outgroups: List[str] = outgroups

        #: The number of ingroups.
        self.n_ingroups: int = int(n_ingroups)

        #: The number of outgroups.
        self.n_outgroups: int = len(outgroups)

        #: The seed for the random number generator.
        self.seed: int | None = seed

        #: The subsampling mode.
        self.subsample_mode: Literal['random', 'probabilistic'] = subsample_mode

        #: The random number generator.
        self.rng: np.random.Generator = np.random.default_rng(seed=self.seed)

        #: The outgroup mask.
        self._outgroup_mask: np.ndarray | None = None

        #: The outgroup indices.
        self._outgroup_indices: np.ndarray | None = None

        #: The ingroup mask.
        self._ingroup_mask: np.ndarray | None = None

        #: 1-based positions of lowest and highest site position per contig (only when target_site_counter is used)
        # noinspection PyTypeChecker
        self._contig_bounds: Dict[str, Tuple[int, int]] = defaultdict(lambda: (np.inf, -np.inf))

    def _prepare_masks(self, samples: List[str]):
        """
        Prepare the masks for ingroups and outgroups.

        :param samples: All samples.
        :raises ValueError: If not all specified outgroups are among the samples.
        """
        is_excluded = np.isin(samples, self.exclude)

        # create mask for ingroups
        if self.ingroups is None:
            # everything that is neither an outgroup nor excluded
            self._ingroup_mask = ~ np.isin(samples, self.outgroups) & ~ is_excluded
        else:
            self._ingroup_mask = np.isin(samples, self.ingroups) & ~ is_excluded

        # create mask for outgroups
        self._outgroup_mask = np.isin(samples, self.outgroups)

        # make sure all specified outgroups are present
        if np.sum(self._outgroup_mask) != len(self.outgroups):
            # get missing outgroups
            missing = np.array(self.outgroups)[~np.isin(self.outgroups, samples)]

            raise ValueError(f"The specified outgroups ({', '.join(missing)}) are not present in the VCF file.")

        # outgroup indices
        # we ignore the order when using the mask
        self._outgroup_indices = np.array([samples.index(outgroup) for outgroup in self.outgroups])

        # inform of the number of ingroups
        self._logger.info(f"Subsampling {self.n_ingroups} ingroup haplotypes " +
                          ("randomly " if self.subsample_mode == "random" else "probabilistically ") +
                          f"from {np.sum(self._ingroup_mask)} individuals in total.")

        # inform on outgroup samples
        self._logger.info(f"Using {np.sum(self._outgroup_mask)} outgroup samples ({', '.join(self.outgroups)}).")

    def _setup(self, handler: MultiHandler):
        """
        Add info fields to the header, load the VCF reader and prepare the sample masks.

        :param handler: The handler.
        """
        super()._setup(handler)

        # add AA info field
        handler._reader.add_info_to_header({
            'ID': self._handler.info_ancestral + '_info',
            'Number': '.',
            'Type': 'String',
            'Description': 'Additional information about the ancestral allele.'
        })

        # add AA probability field
        handler._reader.add_info_to_header({
            'ID': self._handler.info_ancestral + '_prob',
            'Number': '.',
            'Type': 'Float',
            'Description': 'Probability that the predicted ancestral allele is correct, as opposed to the other allele.'
        })

        # set reader
        self._reader = self._handler.load_vcf()

        # prepare masks
        self._prepare_masks(handler._reader.samples)

    @staticmethod
    def _subsample(
            genotypes: np.ndarray,
            size: int,
            rng: np.random.Generator
    ) -> np.ndarray:
        """
        Subsample a set of bases without replacement.

        :param genotypes: A list of bases.
        :param size: The size of the subsample. Capped at the number of available bases.
        :param rng: The random number generator used to draw the subsample.
        :return: A subsample of the bases.
        """
        if genotypes.shape[0] == 0:
            return np.array([])

        subsamples = rng.choice(
            a=genotypes.shape[0],
            size=min(size, genotypes.shape[0]),
            replace=False
        )

        return genotypes[subsamples]

    @staticmethod
    def _get_outgroup_bases(
            genotypes: np.ndarray,
            n_outgroups: int
    ) -> np.ndarray:
        """
        Get the outgroup bases for a variant.

        :param genotypes: The VCF genotype strings.
        :param n_outgroups: The number of outgroups.
        :return: The outgroup bases, ``'.'`` where no base was called.
        """
        outgroup_bases = np.full(n_outgroups, '.')

        for i, genotype in enumerate(genotypes):
            called_bases = get_called_bases([genotype])

            # only the first called base of an outgroup is used
            if len(called_bases) > 0:
                outgroup_bases[i] = called_bases[0]

        return outgroup_bases

    @classmethod
    def _subsample_site(
            cls,
            mode: Literal['random', 'probabilistic'],
            n: int,
            samples: np.ndarray,
            rng: np.random.Generator
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Subsample a site, either randomly or probabilistically.

        :param mode: The subsampling mode.
        :param n: The number of ingroups to subsample to.
        :param samples: The samples.
        :param rng: The random number generator, only used in ``random`` mode.
        :return: Major alleles, major allele counts, and multiplicities,
            possibly including zero multiplicities.
        """
        if mode == 'random':
            # subsample ingroups
            samples = cls._subsample(samples, size=n, rng=rng)

        # alleles ordered by decreasing count
        most_common = Counter(samples).most_common()
        ref_allele, n_ref = most_common[0]

        # a random subsample yields exactly one configuration
        if mode == 'random':
            return [ref_allele], [n_ref], [1]

        # if there is only one allele, probabilistic
        # subsampling is trivial
        if len(most_common) < 2:
            return [ref_allele], [n], [1]

        alt_allele = most_common[1][0]

        # every possible count of the most common allele in a subsample of size n,
        # weighted by its hypergeometric probability
        n_majors = np.arange(n + 1)
        major_alleles = np.full(n + 1, ref_allele)
        m = hypergeom.pmf(k=n_majors, M=len(samples), n=n_ref, N=n)

        # flip alleles where the ref allele is not the major allele
        flip = n_majors < (n + 1) // 2
        n_majors[flip] = n - n_majors[flip]
        major_alleles[flip] = alt_allele

        return major_alleles, n_majors, m

    def _parse_variant(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> List[SiteConfig]:
        """
        Parse a VCF variant. We only consider sites that are at most bi-allelic in the in- and outgroups.

        :param variant: The variant.
        :return: List of site configurations containing a single element if subsample_mode is ``random`` or
            possibly multiple elements if subsample_mode is ``probabilistic``.
        :raises _TooFewIngroupsSiteError: If there are too few ingroups to consider a site for ancestral allele
            annotation.
        :raises _PolyAllelicSiteError: If a site has more than two alleles.
        """
        # get the called ingroup bases
        ingroups = get_called_bases(variant.gt_bases[self._ingroup_mask])

        # make sure we have enough ingroups
        if len(ingroups) < self.n_ingroups:
            raise _TooFewIngroupsSiteError()

        # get the called outgroup bases
        # the order does not matter here
        outgroups = get_called_bases(variant.gt_bases[self._outgroup_mask])

        # get total base counts
        counts = Counter(np.concatenate((ingroups, outgroups)))

        # make sure we have at most two alleles
        if len(counts) > 2:
            raise _PolyAllelicSiteError()

        # get the bases
        b: List[str] = list(counts.keys())

        # subsample ingroups either randomly or probabilistically
        major_alleles, n_majors, multiplicities = self._subsample_site(
            mode=self.subsample_mode,
            n=self.n_ingroups,
            samples=ingroups,
            rng=self.rng
        )

        # Get the outgroup bases.
        # The outgroup order is important, so we can't use the mask here.
        outgroup_bases = self.get_base_index(self._get_outgroup_bases(
            genotypes=np.array([variant.gt_bases[i] for i in self._outgroup_indices]),
            n_outgroups=self.n_outgroups
        ))

        # create site configurations
        sites = []
        for major_allele, n_major, multiplicity in zip(major_alleles, n_majors, multiplicities):

            # configurations that cannot occur are dropped
            if not multiplicity > 0:
                continue

            if len(counts) == 2:
                # Take the other allele as the minor allele. We keep track of the minor allele
                # even if it wasn't contained in the ingroup subsample.
                minor_base: str = b[0] if b[0] != major_allele else b[1]
            else:
                minor_base: str = '.'

            # create site configuration
            sites.append(SiteConfig(
                major_base=base_indices[major_allele],
                n_major=n_major,
                minor_base=self.get_base_index(minor_base),
                outgroup_bases=outgroup_bases,
                multiplicity=multiplicity
            ))

        return sites

    @staticmethod
    def get_base_string(indices: int | np.ndarray) -> str | np.ndarray:
        """
        Get base string(s) from base index/indices.

        :param indices: The base index/indices, where -1 denotes a missing base.
        :return: Base string(s), ``'.'`` for missing bases.
        """
        if isinstance(indices, np.ndarray):

            if len(indices) == 0:
                return np.array([])

            is_valid = indices != -1

            base_strings = np.full(indices.shape, '.', dtype=str)
            base_strings[is_valid] = bases[indices[is_valid]]

            return base_strings

        # assume integer
        if indices != -1:
            return bases[indices]

        return '.'

    @classmethod
    def get_base_index(cls, base_string: str | np.ndarray) -> int | np.ndarray:
        """
        Get base index/indices from base string(s).

        :param base_string: The base string(s).
        :return: Base index/indices, -1 for anything that is not one of the four bases.
        """
        if isinstance(base_string, np.ndarray):
            return np.array([cls.get_base_index(b) for b in base_string], dtype=int)

        # assume string
        return base_indices.get(base_string, -1)
[docs]
class MaximumLikelihoodAncestralAnnotation(_OutgroupAncestralAlleleAnnotation):
"""
Annotation of ancestral alleles following the probabilistic model of EST-SFS
(https://doi.org/10.1534/genetics.118.301120). Note that only bi-allelic SNPs are supported by this model.
By default, the info field ``AA`` (see :attr:`Annotator.info_ancestral`) is added to the VCF file, which holds the
ancestral allele. To be used with :class:`Annotator` or :class:`~fastdfe.parser.Parser`.
The info field ``AA_prob`` holds the probability that the
predicted ancestral allele (``AA`` tag) is correct, as opposed to the other allele. This probability can be also
be used by :class:`~fastdfe.parser.Parser` to polarize the SFS according to the ancestral allele probability.
In addition to the ancestral allele, the info field ``AA_info`` is added, which contains additional information
about the ancestral allele (see :class:`SiteInfo` for an overview of the available information).
This class can also be used independently, see the :meth:`from_dataframe`, :meth:`from_data` and
:meth:`from_est_sfs` methods.
Initially, the branch rates are determined using MLE. Similar to :class:`Parser`, we can also specify the number of
mutational target sites (see the ``n_target_sites`` argument) in case our VCF file does not contain the full set of
monomorphic sites. This is necessary to obtain realistic branch rate estimates. You can also choose a prior for the
polarization probabilities (see :class:`PolarizationPrior`). Eventually, for every site, the probability that the
major allele is ancestral is calculated.
When annotating the variants of a VCF file, we check the most likely ancestral allele against a naive
ad-hoc ancestral allele annotation, and record the sites for which we have disagreement. You might want to
sanity-check the mismatches to make sure the model has been properly specified (see :attr:`mismatches`).
.. note::
* The polarization prior corresponds to the Kingman coalescent probability by default. Using an adaptive prior,
as in the EST-SFS paper, is also possible, but this is only recommended if the number of sites used for the
inference is large (see :attr:`prior`).
* The model can only handle sites that have at most 2 alleles across the in- and outgroups, so sites with more
than 2 alleles are ignored. Only variants that are at most bi-allelic in the provided in- and outgroups are
annotated.
* The model determines the probability of the major allele being ancestral opposed to the minor allele. This can
be problematic if the actual ancestral allele is not contained in the ingroup (possibly due to subsampling).
To avoid this issue, we also keep track of potential minor alleles at frequency 0. If we were to ignore this,
it would be impossible to infer divergence, i.e. fixed derived allele that are no longer observed in the
ingroups (see :attr:`PolarizationPrior.allow_divergence`). That said, divergence counts are not informative
on DFE inference with fastDFE and allow_divergence should not be set to ``True`` if interested in the SFS.
* The model assumes a single coalescent topology for all sites, in which all outgroups coalesce first with the
ingroup and not with each other. It is important to specify the outgroups in order of increasing divergence
and not to select outgroups that are much more closely related to each other than to the ingroup (as this
would give rise to a different coalescent topology than the one assumed). You can call
:meth:`get_outgroup_divergence` after the inference to check the estimated branch rates for each outgroup.
The assumption of a single fixed topology should be good enough provided that in- and outgroups are
sufficiently diverged.
Example usage:
::
import fastdfe as fd
ann = fd.Annotator(
vcf="https://github.com/Sendrowski/fastDFE/"
"blob/dev/resources/genome/betula/all."
"with_outgroups.subset.10000.vcf.gz?raw=true",
annotations=[fd.MaximumLikelihoodAncestralAnnotation(
outgroups=["ERR2103730"],
n_ingroups=15
)],
output="genome.polarized.vcf.gz"
)
ann.annotate()
"""
#: Column dtypes of the data frame holding the site configurations.
#: The key order determines the order in which missing columns are added.
_dtypes = {
    'n_major': np.int8,
    'multiplicity': np.float64,
    'sites': object,
    'major_base': np.int8,
    'minor_base': np.int8,
    'outgroup_bases': object,
    'p_major_ancestral': np.float64,
    'p_minor': np.float64,
    'p_major': np.float64,
}

#: Columns that jointly identify a site configuration (used as grouping keys).
_group_cols = ['major_base', 'minor_base', 'outgroup_bases', 'n_major']
[docs]
def __init__(
        self,
        outgroups: List[str],
        n_ingroups: int = 11,
        ingroups: List[str] | None = None,
        exclude: List[str] | None = None,
        n_runs: int = 10,
        model: SubstitutionModel = None,
        parallelize: bool = True,
        prior: PolarizationPrior | None = '',
        max_sites: int = 10000,
        seed: int | None = 0,
        confidence_threshold: float = 0,
        n_target_sites: int | None = None,
        n_samples_target_sites: int | None = 100000,
        adjust_target_sites: bool = True,
        subsample_mode: Literal['random', 'probabilistic'] = 'probabilistic'
):
    """
    Initialize the maximum-likelihood ancestral allele annotation.

    :param outgroups: Names of the outgroup samples as they appear in the VCF file, ordered by increasing
        divergence from the ingroups. The order fixes the (predetermined) tree topology whose branch rates
        are optimized: the first entry is the closest outgroup, the last entry the most distant one.
        Additional outgroups make the inference more accurate but more expensive; using more than one is
        recommended, more than three is likely unnecessary. Sites at which these outgroups are absent are
        not used when optimizing the rate parameters. Because of the assumed topology, the outgroups should
        not be much more closely related to each other than to the ingroups. Ideally, the optimized branch
        rates show markedly different values and, in any case, they should increase monotonically along
        the outgroups (see :meth:`get_outgroup_divergence`).
    :param n_ingroups: Minimum number of ingroups that must be present at a site for it to be considered.
        Ingroups are subsampled to this size because the model requires the same number of ingroups at
        every site. A larger value does not necessarily improve the inference (see ``prior``) and can
        inflate the variance of the polarization probabilities across frequency counts, so it should only
        be large if many sites are available. A sensible value for a few thousand sites is 10, or perhaps
        20 for larger numbers of sites. Very small values can make the subsample unrepresentative of the
        actual allele frequencies, especially without probabilistic subsampling (see ``subsample_mode``),
        and they also reduce the number of frequency bins used for the polarization probabilities. If
        ``n_ingroups`` is even, the major allele is chosen arbitrarily whenever major and minor allele
        counts are equal; use an odd number to avoid this.
    :param ingroups: Names of the ingroup samples as they appear in the VCF file. If ``None``, all
        (non-outgroup) samples are used. Has to be at least as large as ``n_ingroups``.
    :param exclude: Names of samples, as they appear in the VCF file, to exclude from the ingroup.
    :param n_runs: Number of optimization runs used to determine the branch rates. Use
        :meth:`plot_likelihoods` to check that the runs reach similar likelihoods.
    :param model: The substitution model. Defaults to :class:`K2SubstitutionModel`.
    :param parallelize: Whether to parallelize the computation across multiple cores.
    :param prior: The prior on the polarization probabilities, see :class:`KingmanPolarizationPrior` and
        :class:`AdaptivePolarizationPrior`. Defaults to :class:`KingmanPolarizationPrior`; pass ``None``
        to use no prior.
    :param max_sites: Maximum number of sites to consider. Restricting the inference to a reasonably large
        subset (on the order of a few thousand bi-allelic sites) can speed things up considerably as
        parsing can be slow. The subset is used to calibrate the rate parameters and possibly the
        polarization priors.
    :param seed: Seed for the random number generator. ``None`` selects a random seed; the default is 0.
    :param confidence_threshold: Confidence threshold for annotating the ancestral allele. A site is only
        annotated if the probability of the major allele being ancestral (as opposed to the minor allele)
        lies outside ``((1 - confidence_threshold) / 2, 1 - (1 - confidence_threshold) / 2)``. Values
        close to ``0`` annotate as many sites as possible, values close to ``1`` only sites whose
        ancestral state is very clear.

        .. warning:: The threshold biases the result by excluding more sites with high-frequency derived
            alleles. Keep it at ``0`` if the distribution of frequency counts matters, e.g., when
            determining the SFS.
    :param n_target_sites: Total number of target sites (mono- and bi-allelic) underlying the dataset when
        used together with :class:`Parser` or :class:`Annotator`. Useful if the provided sites are all
        bi-allelic, since ignoring mono-allelic sites leads to overestimated rate parameters. Requires a
        FASTA file to be passed to :class:`Parser` or :class:`Annotator`, from which mono-allelic sites
        are sampled between the first and last variant of every contig considered in the VCF file. Use
        ``None`` to disable. If unspecified, the value is carried over from :class:`Parser` when
        available. By default, the number of mono-allelic sites is extrapolated based on the ratio of
        sites with called outgroup bases parsed from the VCF file (``adjust_target_sites``).
    :param n_samples_target_sites: Number of sites to sample from the FASTA file, from which the total
        number of target sites (``n_target_sites``) is extrapolated.
    :param adjust_target_sites: Whether to adjust the number of target sites based on the parsed VCF sites
        relative to the total number of sites in the VCF. Defaults to ``True``.
    :param subsample_mode: Subsampling mode. ``random`` draws once without replacement from all available
        ingroup genotypes per site. ``probabilistic`` integrates over the hypergeometric distribution when
        parsing and computing the ancestral probabilities; it takes a bit longer but yields much more
        stable results with far fewer sites and is highly recommended.
    :raises ValueError: If no outgroup is specified.
    """
    super().__init__(
        outgroups=outgroups,
        ingroups=ingroups,
        exclude=exclude,
        n_ingroups=n_ingroups,
        subsample_mode=subsample_mode,
        seed=seed
    )

    # the model is defined relative to at least one outgroup
    if len(outgroups) == 0:
        raise ValueError("Must specify at least one outgroup. If you do not have any outgroup "
                         "information, consider using MaximumParsimonyAncestralAnnotation instead.")

    # explicitly given ingroups must be able to provide n_ingroups haplotypes (two per sample)
    too_few_ingroups = ingroups is not None and len(ingroups) * 2 < n_ingroups
    if too_few_ingroups:
        self._logger.warning("The number of specified ingroup samples is smaller than the "
                             "number of ingroups (assumed diploidy). Please make sure to "
                             "provide sufficiently many ingroups.")

    # thresholding distorts the frequency spectrum, so make the user aware of it
    if confidence_threshold > 0:
        self._logger.warning("Please be aware that a confidence threshold of greater than 0 biases the SFS "
                             "towards fewer high-frequency derived alleles.")

    # resolve sentinel defaults ('' requests the default prior as None means "no prior")
    if prior == '':
        prior = KingmanPolarizationPrior()

    if model is None:
        model = K2SubstitutionModel()

    #: Whether to parallelize the computation.
    self.parallelize: bool = parallelize

    #: Maximum number of sites to consider
    self.max_sites: int = max_sites

    #: The confidence threshold for the ancestral allele annotation.
    self.confidence_threshold: float = confidence_threshold

    #: The prior to use for the polarization probabilities.
    self.prior: PolarizationPrior | None = prior

    #: Number of random ML starts when determining the rate parameters
    self.n_runs: int = int(n_runs)

    #: The substitution model.
    self.model: SubstitutionModel = model

    #: The parameter names in the order they are passed to the optimizer.
    self.param_names: List[str] = list(self.model.get_bounds(self.n_outgroups).keys())

    #: The VCF reader.
    self._reader: Optional['cyvcf2.VCF'] = None

    #: The data frame holding all site configurations.
    self.configs: pd.DataFrame | None = None

    #: The probability of all sites per frequency bin.
    self.p_bins: Dict[str, np.ndarray | None] | None = None

    #: The total number of valid sites parsed (including sites not considered for ancestral allele inference).
    self.n_sites: int | None = None

    #: The log likelihoods for the different runs when optimizing the rate parameters.
    self.likelihoods: np.ndarray | None = None

    #: The best log likelihood when optimizing the rate parameters.
    self.likelihood: float | None = None

    #: Optimization result of the best run.
    self.result: Optional['scipy.optimize.OptimizeResult'] = None

    #: The MLE parameters.
    self.params_mle: Dict[str, float] | None = None

    #: The MLE parameters for all runs.
    self.params_mle_runs: pd.DataFrame | None = None

    #: Mismatches between the most likely ancestral allele and the ad-hoc ancestral allele.
    # This is only computed when annotating a VCF file, and only contains the mismatches
    # for sites that were actually annotated.
    self.mismatches: List[SiteInfo] = []

    #: The total number of target sites.
    self.n_target_sites: int | None = n_target_sites

    #: The number of sites to sample from the FASTA file when determining the number of target sites.
    self.n_samples_target_sites: int = n_samples_target_sites

    #: Whether to adjust the number of target sites based on the number of sites parsed from the VCF file.
    self.adjust_target_sites: bool = adjust_target_sites

    #: The monomorphic site counts sampled from the FASTA file.
    self._monomorphic_samples: Dict[str, int] = dict.fromkeys('ACGT', 0)
def _setup(self, handler: MultiHandler):
    """
    Prepare the annotation: parse the VCF file, optionally sample mono-allelic sites,
    and optimize the model.

    :param handler: The handler.
    """
    from .parser import Parser, TargetSiteCounter

    super()._setup(handler)

    # inherit the number of target sites from the parser unless it was set explicitly
    parser_has_counter = (
            isinstance(handler, Parser) and
            isinstance(handler.target_site_counter, TargetSiteCounter)
    )

    if parser_has_counter and self.n_target_sites is None:
        self.n_target_sites = handler.target_site_counter.n_target_sites

        self._logger.debug(f"Using n_target_sites={self.n_target_sites} from Parser.")

    # sampling mono-allelic sites is only possible with a FASTA file
    if self.n_target_sites is not None:
        handler._require_fasta(self.__class__.__name__)

    # read the site configurations from the VCF
    self._parse_vcf()

    # complement the variants with mono-allelic sites if requested
    if self.n_target_sites is not None:
        self._sample_mono_allelic_sites()

    # report how many sites enter the inference
    self._log_stats()

    # let the substitution model prepare itself for this annotation
    self.model._setup(self)

    # optimize the rates and infer the ancestral alleles
    self.infer()
def _log_stats(self):
    """
    Report the number of sites (total, polymorphic and monomorphic) included in the inference.
    """
    included = self._get_mle_configs()

    # monomorphic configurations carry no minor allele
    is_monomorphic = included.minor_base == -1

    n_sites = int(np.round(included.multiplicity.sum()))
    n_monomorphic = int(np.round(included.loc[is_monomorphic, 'multiplicity'].sum()))
    n_polymorphic = n_sites - n_monomorphic

    self._logger.info(
        f"Included {n_sites} sites for the inference ({n_polymorphic} polymorphic, {n_monomorphic} monomorphic)."
    )
def _get_n_target_sites_adjusted(self) -> int:
"""
Get the number of target sites adjusted by the number of sites parsed.
This assumed that the sites have not been sampled yet.
"""
if self.adjust_target_sites:
ratio = self._get_mle_configs().multiplicity.sum() / self.n_sites
else:
ratio = 1
return int(ratio * (self.n_target_sites - self.n_sites))
def _get_n_sites(self) -> int:
"""
Get the number of sites to consider.
"""
return int(self.configs.multiplicity.sum())
def _sample_mono_allelic_sites(self):
    """
    Sample mono-allelic sites from the FASTA file and add them to the site configurations.

    Bases are drawn at random positions within the per-contig ranges stored in ``self._contig_bounds``,
    the resulting base counts are extrapolated to the (adjusted) number of target sites, and the
    extrapolated counts are stored in ``self._monomorphic_samples`` and merged into ``self.configs``.
    ``self.n_sites`` is updated accordingly.

    :raises ValueError: If ``n_target_sites`` is smaller than the number of sites parsed.
    """
    # inform
    self._logger.info(f"Sampling mono-allelic sites.")

    if self.n_target_sites < self.n_sites:
        raise ValueError(f"The number of target sites ({self.n_target_sites}) must be at least "
                         f"as large as the number of sites parsed ({self.n_sites}).")

    # check that we have enough sites to sample
    if self.n_samples_target_sites <= 0 or len(self._contig_bounds) == 0:
        self._logger.info("No mono-allelic sites to sample, skipping.")
        return

    # initialize progress bar
    pbar = tqdm(
        total=self.n_samples_target_sites,
        desc=f'{self.__class__.__name__}>Sampling mono-allelic sites',
        disable=Settings.disable_pbar
    )

    # get array of ranges per contig of parsed variants
    ranges = np.array(list(self._contig_bounds.values()))

    # get range sizes
    range_sizes = ranges[:, 1] - ranges[:, 0]

    # determine sampling probabilities (proportional to the range size of each contig)
    # NOTE(review): if every range has size zero this divides by zero -- confirm this cannot happen
    probs = range_sizes / np.sum(range_sizes)

    # sample number of sites per contig
    sample_counts = self.rng.multinomial(self.n_samples_target_sites, probs)

    # sampled bases
    samples = dict(A=0, C=0, G=0, T=0)

    # iterate over contigs
    for contig, bounds, n in zip(self._contig_bounds.keys(), ranges, sample_counts):
        # get aliases
        aliases = self._handler.get_aliases(contig)

        # make sure we have a valid range
        if bounds[1] > bounds[0] and n > 0:
            self._logger.debug(f"Sampling {n} sites from contig '{contig}'.")

            # fetch contig
            record = self._handler.get_contig(aliases, notify=False)

            # sample sites; draws that do not hit one of the four bases are rejected and repeated
            # NOTE(review): this loop does not terminate if the range contains no A/C/G/T at all
            i = 0
            while i < n:
                pos = self.rng.integers(*bounds)

                # presumably ``pos`` is a 1-based VCF position, hence the shift -- verify
                base = record.seq[pos - 1]

                if base in bases:
                    # increase counters
                    samples[base] += 1
                    i += 1
                    pbar.update()

    # close progress bar
    pbar.close()

    # rewind fasta iterator
    FASTAHandler._rewind(self._handler)

    n_target_sites = self._get_n_target_sites_adjusted()

    if self.adjust_target_sites:
        self._logger.info(f"Extrapolating to {n_target_sites} mutational target sites "
                          f"based on the number of sites parsed.")
    else:
        self._logger.info(f"Extrapolating to {self.n_target_sites} mutational target sites.")

    # ratio for extrapolating to the total number of target sites
    ratio = n_target_sites / self.n_samples_target_sites

    # extrapolate the number of monomorphic sites: round the scaled counts down and hand the
    # remaining sites to the bases with the largest fractional parts so that the counts are integers
    vec = np.array([samples[k] for k in bases], dtype=float)
    scaled = vec * ratio
    floors = np.floor(scaled).astype(int)
    remainder = n_target_sites - int(floors.sum())

    if remainder > 0:
        order = np.argsort(scaled - floors)[::-1]
        floors[order[:remainder]] += 1

    self._monomorphic_samples = {k: int(v) for k, v in zip(bases, floors)}

    # add monomorphic site counts to data frame
    self.configs = self._add_monomorphic_sites(self._monomorphic_samples)

    # update number of sites
    self.n_sites = self._get_n_sites()
def _add_monomorphic_sites(self, samples: Dict[str, int]):
    """
    Merge monomorphic site counts into the data frame holding the site configurations.

    Each base yields one configuration in which the ingroups and all outgroups carry
    that base. The new sites receive consecutive indices following the sites parsed so far.

    :param samples: The number of monomorphic sites per base.
    :return: The data frame of site configurations, aggregated by configuration.
    """
    # boundaries of the consecutive index blocks assigned to each base
    offsets = np.concatenate(([self.n_sites], self.n_sites + np.cumsum(list(samples.values()))))

    # one row per base
    rows = []
    for k, (base, count) in enumerate(samples.items()):
        base_index = base_indices[base]

        rows.append(dict(
            n_major=self.n_ingroups,
            major_base=base_index,
            minor_base=-1,
            outgroup_bases=(base_index,) * self.n_outgroups,
            multiplicity=count,
            sites=list(range(offsets[k], offsets[k + 1])),
            n_outgroups=self.n_outgroups
        ))

    monomorphic = pd.DataFrame(rows)

    # append to the existing configurations
    combined = pd.concat((self.configs, monomorphic))

    # collapse identical configurations
    keys = self._group_cols + ['n_outgroups']

    return combined.groupby(keys, as_index=False, dropna=False).sum()
def _teardown(self):
    """
    Finalize the annotation and report the number of recorded mismatches.
    """
    super()._teardown()

    # report how often the model disagreed with the ad-hoc annotation
    n_mismatches = len(self.mismatches)

    self._logger.info(f"There were {n_mismatches} mismatches between the most likely "
                      f"ancestral allele and the ad-hoc ancestral allele annotation.")
@classmethod
def _parse_est_sfs(cls, data: pd.DataFrame) -> pd.DataFrame:
"""
Parse EST-SFS data.
:param data: The data frame.
:return: The site configurations.
"""
# extract the number of outgroups
n_outgroups = data.shape[1] - 1
# retain site index
data['sites'] = data.index
data['sites'] = data.sites.apply(lambda x: [x])
# the first column contains the ingroup counts, split them
ingroup_data = data[0].str.split(',', expand=True).astype(np.int8).to_numpy()
# determine the number of major alleles per site
data['n_major'] = ingroup_data.max(axis=1)
# sort by the number of alleles
data_sorted = ingroup_data.argsort(axis=1)
# determine the number of major alleles per site
data['major_base'] = data_sorted[:, -1]
data['major_base'] = data.major_base.astype(cls._dtypes['major_base'])
# determine the mono-allelic sites
poly_allelic = (ingroup_data > 0).sum(axis=1) > 1
# determine the minor alleles
minor_bases = np.full(data.shape[0], -1, dtype=np.int8)
minor_bases[poly_allelic] = data_sorted[:, -2][poly_allelic]
# assign the minor alleles
data['minor_base'] = minor_bases
# extract outgroup data
outgroup_data = np.full((data.shape[0], n_outgroups), -1, dtype=np.int8)
for i in range(n_outgroups):
# get the genotypes
genotypes = data[i + 1].str.split(',', expand=True).astype(np.int8).to_numpy()
# determine whether the site has an outgroup
has_outgroup = genotypes.sum(axis=1) > 0
# determine the outgroup allele indices provided the site has an outgroup
outgroup_data[has_outgroup, i] = genotypes[has_outgroup].argmax(axis=1)
# assign the outgroup data, convert to tuples for hashing
data['outgroup_bases'] = [tuple(row) for row in outgroup_data]
# return new columns only
return data.drop(range(n_outgroups + 1), axis=1)
[docs]
@classmethod
def from_est_sfs(
        cls,
        file: str,
        prior: PolarizationPrior | None = '',
        n_runs: int = 10,
        model: SubstitutionModel = None,
        parallelize: bool = True,
        seed: int = 0,
        chunk_size: int = 100000
) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Create instance from EST-SFS input file.

    The number of ingroups is taken to be the total ingroup base count of the first site, i.e. all
    sites are assumed to have the same number of ingroup copies.

    :param file: File containing EST-SFS-formatted input data.
    :param prior: The prior to use for the polarization probabilities (see :meth:`__init__`).
    :param n_runs: Number of runs for rate estimation (see :meth:`__init__`).
    :param model: The substitution model (see :meth:`__init__`).
    :param parallelize: Whether to parallelize the runs (see :meth:`__init__`).
    :param seed: The seed to use for the random number generator.
    :param chunk_size: The chunk size for reading the file.
    :return: The instance.
    :raises ValueError: If the file contains no data.
    """
    # accumulated site configurations
    data = None
    n_ingroups = 0

    # read the file in chunks to keep memory usage bounded; the reader is closed on exit
    with pd.read_csv(file, sep=r"\s+", header=None, dtype=str, chunksize=chunk_size) as reader:
        for chunk in reader:

            if data is None:
                # The ingroup column holds the counts of A, C, G and T, so the number of ingroups
                # is their sum. Taking the maximum would only be correct for monomorphic sites.
                n_ingroups = int(np.array(chunk.iloc[0, 0].split(','), dtype=int).sum())

            # parse the chunk into site configurations
            parsed = cls._parse_est_sfs(chunk)

            # merge with the configurations of previous chunks
            data = parsed if data is None else pd.concat([data, parsed])

            # collapse identical configurations (this concatenates the lists of site indices)
            data = data.groupby(cls._group_cols, as_index=False, dropna=False).sum()

    # check if there is data
    if data is None:
        raise ValueError("No data found.")

    # determine the multiplicity
    data['multiplicity'] = data['sites'].apply(lambda x: len(x))

    # create from dataframe
    return cls.from_dataframe(
        data=data,
        n_runs=n_runs,
        model=model,
        parallelize=parallelize,
        prior=prior,
        n_ingroups=n_ingroups,
        grouped=True,
        seed=seed
    )
[docs]
def to_est_sfs(self, file: str):
    """
    Write the site configurations to a file in EST-SFS format.

    Every line holds the comma-separated ingroup base counts followed by the tab-separated,
    comma-separated base counts of each outgroup.

    :param file: The output file name.
    """
    # configuration index of every site, where -1 marks sites that are not included
    site_indices = self._get_site_indices()
    included = site_indices[site_indices != -1]

    # one row per included site
    rows = self.configs.iloc[included]

    with open(file, 'w') as out:
        for label, row in rows.iterrows():
            major, minor, n_major = row['major_base'], row['minor_base'], row['n_major']

            # ingroup counts per base
            ingroup_counts = np.zeros(4, dtype=int)
            ingroup_counts[major] = n_major

            # the minor allele makes up the remaining ingroups unless the site is mono-allelic
            if minor != -1:
                ingroup_counts[minor] = self.n_ingroups - n_major

            # outgroup counts per base, a single count for every outgroup that is present
            outgroup_counts = np.zeros((self.n_outgroups, 4), dtype=int)
            for k, base in enumerate(row['outgroup_bases']):
                if base != -1:
                    outgroup_counts[k, base] = 1

            ingroup_field = ','.join(ingroup_counts.astype(str))
            outgroup_fields = '\t'.join(','.join(counts) for counts in outgroup_counts.astype(str))

            out.write(f"{ingroup_field}\t{outgroup_fields}\n")

            # stop once the maximum number of sites is reached
            # NOTE(review): ``label`` is the index label of the configuration row, not the number
            # of lines written so far -- confirm that this is the intended stopping criterion
            if label + 1 >= self.max_sites:
                break
[docs]
def to_file(self, file: str):
    """
    Serialize the object to JSON and write it to a file.

    .. note::
        References to the handler and the reader are discarded.

    :param file: File path.
    """
    with open(file, 'w') as out:
        serialized = self.to_json()
        out.write(serialized)
[docs]
@classmethod
def from_file(cls, file: str) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Restore an object from a file written by :meth:`to_file`.

    .. note::
        The handler and the reader are not restored, so serialization is mostly useful
        for analyzing the results in detail, not for further processing.

    .. warning::
        The content is decoded with ``jsonpickle`` (see :meth:`from_json`), so only load
        files from trusted sources.

    :param file: File path.
    :return: The object.
    """
    with open(file, 'r') as source:
        content = source.read()

    return cls.from_json(content)
[docs]
def to_json(self) -> str:
    """
    Serialize the object to a JSON string.

    .. note::
        References to the handler and the reader are discarded, also on the object itself.

    :return: JSON string
    """
    # drop the references before encoding
    for attribute in ('_reader', '_handler'):
        setattr(self, attribute, None)

    return jsonpickle.encode(self, indent=4, warn=True)
[docs]
@classmethod
def from_json(cls, json: str) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Restore an object from a JSON string created by :meth:`to_json`.

    .. note::
        The handler and the reader are not restored, so serialization is mostly useful
        for analyzing the results in detail, not for further processing.

    .. warning::
        ``jsonpickle.decode`` reconstructs arbitrary Python objects, so only decode strings
        from trusted sources.

    :param json: JSON string.
    :return: The object.
    """
    restored = jsonpickle.decode(json)

    configs = restored.configs

    # convert index to int if necessary
    if 'index' in configs:
        configs.index = configs.index.astype(int)

    return restored
[docs]
@classmethod
def from_data(
        cls,
        n_major: Iterable[int],
        major_base: Iterable[str | int],
        minor_base: Iterable[str | int],
        outgroup_bases: Iterable[Iterable[str | int]],
        n_ingroups: int,
        n_runs: int = 10,
        model: SubstitutionModel = None,
        parallelize: bool = True,
        prior: PolarizationPrior | None = '',
        seed: int = 0,
        pass_indices: bool = False,
        confidence_threshold: float = 0
) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Create an instance from per-site data passed directly.

    :param n_major: The number of major alleles per site. This number must not exceed ``n_ingroups``, as we
        consider the number of major alleles of subsamples of size ``n_ingroups``.
    :param major_base: The major allele per site. A string representation of the base, or the base index
        according to ``['A', 'C', 'G', 'T']`` if ``pass_indices`` is ``True``. Undefined bases are denoted
        by ``None`` when ``pass_indices`` is ``False`` and by ``-1`` when ``pass_indices`` is ``True``.
    :param minor_base: The minor allele per site, encoded in the same way as ``major_base``.
    :param outgroup_bases: The outgroup alleles per site, encoded in the same way as ``major_base``. This
        should be a list of lists, where the outer list corresponds to the sites and the inner list to
        the outgroups per site. All sites are required to have the same number of outgroups.
    :param n_ingroups: The number of ingroup samples (see :meth:`__init__`).
    :param n_runs: The number of runs for rate estimation (see :meth:`__init__`).
    :param model: The substitution model (see :meth:`__init__`).
    :param parallelize: Whether to parallelize the runs.
    :param prior: The prior to use for the polarization probabilities (see :meth:`__init__`).
    :param seed: The seed for the random number generator.
    :param pass_indices: Whether to pass the base indices instead of the bases.
    :param confidence_threshold: The confidence threshold for the ancestral allele annotation
        (see :meth:`__init__`).
    :return: The instance.
    :raises ValueError: If a major allele count exceeds ``n_ingroups``.
    """

    def to_indices(values) -> np.ndarray:
        """Translate base strings into base indices."""
        return cls.get_base_index(np.array(list(values)))

    major_counts = np.array(list(n_major), dtype=np.int8)

    # a subsample of size n_ingroups cannot contain more major alleles than that
    if np.any(major_counts > n_ingroups):
        raise ValueError("Major allele counts cannot be larger than the number of ingroups.")

    # translate base strings unless indices were passed
    if not pass_indices:
        major_base = to_indices(major_base)
        minor_base = to_indices(minor_base)

        # one row per site, one column per outgroup
        outgroup_bases = to_indices(outgroup_bases).reshape(len(major_base), -1)

    # one row per site
    data = pd.DataFrame(dict(
        n_major=major_counts,
        major_base=major_base,
        minor_base=minor_base,
        outgroup_bases=list(outgroup_bases)
    ))

    return cls.from_dataframe(
        data=data,
        n_ingroups=n_ingroups,
        n_runs=n_runs,
        model=model,
        parallelize=parallelize,
        prior=prior,
        seed=seed,
        confidence_threshold=confidence_threshold
    )
@classmethod
def _from_vcf(
        cls,
        file: str,
        outgroups: List[str],
        n_ingroups: int,
        ingroups: List[str] = None,
        exclude: List[str] = None,
        n_runs: int = 10,
        model: SubstitutionModel = None,
        parallelize: bool = True,
        prior: PolarizationPrior | None = '',
        max_sites: int = np.inf,
        seed: int | None = 0,
        confidence_threshold: float = 0,
        subsample_mode: Literal['random', 'probabilistic'] = 'probabilistic'
) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Create an instance from a VCF file. In most cases, it is recommended to use the :class:`Annotator` or
    :class:`~fastdfe.parser.Parser` classes instead.

    The sites are parsed, but the rates are not optimized; see :meth:`infer`.

    :param file: The VCF file.
    :param outgroups: Same as in :meth:`__init__`.
    :param n_ingroups: Same as in :meth:`__init__`.
    :param ingroups: Same as in :meth:`__init__`.
    :param exclude: Same as in :meth:`__init__`.
    :param n_runs: Same as in :meth:`__init__`.
    :param model: Same as in :meth:`__init__`. ``None`` selects the default model.
    :param parallelize: Same as in :meth:`__init__`.
    :param prior: Same as in :meth:`__init__`. The empty string selects the default prior,
        ``None`` disables the prior.
    :param max_sites: Same as in :meth:`__init__`.
    :param seed: Same as in :meth:`__init__`.
    :param confidence_threshold: Same as in :meth:`__init__`.
    :param subsample_mode: Same as in :meth:`__init__`.
    :return: The instance.
    """
    # The defaults for ``model`` and ``prior`` are resolved in ``__init__`` so that every call
    # obtains fresh objects instead of sharing instances created at function definition time.
    anc = MaximumLikelihoodAncestralAnnotation(
        outgroups=outgroups,
        n_ingroups=n_ingroups,
        ingroups=ingroups,
        exclude=exclude,
        n_runs=n_runs,
        model=model,
        parallelize=parallelize,
        prior=prior,
        max_sites=max_sites,
        seed=seed,
        confidence_threshold=confidence_threshold,
        subsample_mode=subsample_mode
    )

    # Set up the handler using the parent's ``_setup`` only, as this class's own ``_setup``
    # would also run the inference. The class is named explicitly since ``anc`` is always an
    # instance of it, whereas ``cls`` may be a subclass.
    super(MaximumLikelihoodAncestralAnnotation, anc)._setup(MultiHandler(
        vcf=file,
        max_sites=max_sites,
        seed=seed
    ))

    # parse the variants
    anc._parse_vcf()

    return anc
[docs]
@classmethod
def from_dataframe(
        cls,
        data: pd.DataFrame,
        n_ingroups: int,
        n_runs: int = 10,
        model: SubstitutionModel = None,
        parallelize: bool = True,
        prior: PolarizationPrior | None = '',
        seed: int = 0,
        grouped: bool = False,
        confidence_threshold: float = 0
) -> 'MaximumLikelihoodAncestralAnnotation':
    """
    Create an instance from a dataframe. The passed dataframe is not modified.

    :param data: Dataframe with the columns: ``major_base``, ``minor_base``, ``outgroup_bases``, ``n_major`` of
        type ``int``, ``int``, ``list`` and ``int``, respectively. The outgroup bases should have the same length
        for every site.
    :param n_ingroups: The number of ingroups (see :meth:`__init__`).
    :param n_runs: Number of runs for rate estimation (see :meth:`__init__`).
    :param model: The substitution model (see :meth:`__init__`).
    :param parallelize: Whether to parallelize computations.
    :param prior: The prior to use for the polarization probabilities (see :meth:`__init__`).
    :param seed: The seed for the random number generator. If ``None``, a random seed is chosen.
    :param grouped: Whether the dataframe is already grouped by all columns (used for internal purposes).
    :param confidence_threshold: The confidence threshold for the ancestral allele annotation
        (see :meth:`__init__`).
    :return: The instance.
    :raises ValueError: If the dataframe is empty.
    """
    # check if dataframe is empty
    if data.empty:
        raise ValueError("Empty dataframe.")

    if not grouped:
        # only keep the columns that are needed; work on an explicit copy so that
        # the assignments below neither warn nor affect the caller's dataframe
        data = data[cls._group_cols].copy()

        # retain site index
        data['sites'] = data.index

        # convert outgroup bases to tuples, which are hashable and can be grouped by
        data['outgroup_bases'] = data['outgroup_bases'].apply(tuple)

        # group by all columns and keep track of the site indices
        data = data.groupby(cls._group_cols, as_index=False, dropna=False).agg(list).reset_index(drop=True)

        # determine the multiplicity
        data['multiplicity'] = data['sites'].map(len)

    # add missing columns with NA as default value (``assign`` returns a new dataframe)
    missing = [col for col in cls._dtypes if col not in data.columns]
    data = data.assign(**{col: None for col in missing})

    # convert to the correct dtypes
    data = data.astype(cls._dtypes)

    # determine the number of outgroups present per site configuration
    data['n_outgroups'] = np.sum(np.array(data.outgroup_bases.to_list()) != -1, axis=1)

    # the overall number of outgroups is the largest number present at any site
    n_outgroups = int(data.n_outgroups.max())

    anc = MaximumLikelihoodAncestralAnnotation(
        n_runs=n_runs,
        model=model,
        parallelize=parallelize,
        prior=prior,
        outgroups=[str(i) for i in range(n_outgroups)],  # pseudo names for outgroups
        ingroups=[str(i) for i in range(n_ingroups)],  # pseudo names for ingroups
        n_ingroups=n_ingroups,
        seed=seed,
        confidence_threshold=confidence_threshold,
        subsample_mode='random'
    )

    # assign data frame
    anc.configs = data

    # convert outgroup bases to tuples of native integers
    anc._convert_outgroup_bases_to_native_types()

    # set the number of sites (which coincides with number of sites parsed)
    anc.n_sites = anc._get_n_sites()

    # notify on statistics
    anc._log_stats()

    # set up substitution model
    anc.model._setup(anc)

    return anc
def _convert_outgroup_bases_to_native_types(self):
"""
Convert outgroup bases to tuples of native integers
numpy types cause problems when serializing
"""
self.configs.outgroup_bases = self.configs.outgroup_bases.apply(lambda x: tuple(int(i) for i in x))
def _parse_vcf(self):
    """
    Parse variants from VCF file.

    Iterates over ``self._reader`` and accumulates the site configurations in ``self.configs``,
    one row per distinct combination of major base, minor base, outgroup bases and major allele
    count. Sites raising ``_PolyAllelicSiteError`` or ``_TooFewIngroupsSiteError`` are skipped.
    Sets ``self.n_sites`` to the number of variants iterated over, including skipped ones.
    """
    # initialize data frame
    self.configs = pd.DataFrame(columns=list(self._dtypes.keys()))

    # NOTE(review): ``astype`` returns a new data frame which is discarded here, so the
    # dtypes are not actually applied -- confirm whether this is intended
    self.configs.astype(self._dtypes)

    # columns to use as index
    index_cols = ['major_base', 'minor_base', 'outgroup_bases', 'n_major']

    # set index to initial site configuration
    self.configs.set_index(keys=index_cols, inplace=True)

    # determine the total number of sites to be parsed
    total = min(self._handler.n_sites, self.max_sites)

    # initialize counter in case we do not parse any sites
    i = -1

    # create progress bar
    with self._handler.get_pbar(desc=f"{self.__class__.__name__}>Parsing sites", total=total) as pbar:

        # iterate over sites
        for i, variant in enumerate(self._reader):

            # parse the site, which may yield several configurations
            try:
                configs = self._parse_variant(variant)
            except (_PolyAllelicSiteError, _TooFewIngroupsSiteError):
                # skip sites the model cannot handle
                pass
            else:
                if self.n_target_sites is not None:
                    # update bounds, i.e. the smallest and largest position seen on this contig,
                    # which delimit the range mono-allelic sites are sampled from later on
                    low, high = self._contig_bounds[variant.CHROM]
                    self._contig_bounds[variant.CHROM] = (min(low, variant.POS), max(high, variant.POS))

                for config in configs:

                    # key identifying the site configuration
                    index = (
                        config.major_base,
                        config.minor_base,
                        tuple(config.outgroup_bases),
                        config.n_major
                    )

                    if index in self.configs.index:

                        # get the site data
                        site_data = self.configs.loc[index].to_dict()

                        # update the site data
                        site_data['multiplicity'] += config.multiplicity
                        site_data['sites'] += [i]

                        # update the site data
                        # Note that there were problems updating the data frame directly
                        self.configs.loc[index] = site_data
                    else:
                        # first occurrence of this configuration
                        self.configs.loc[index] = config.__dict__ | {'sites': [i]}

            pbar.update()

            # explicitly stopping after ``n`` sites fixes a bug with cyvcf2:
            # 'error parsing variant with `htslib::bcf_read` error-code: 0 and ret: -2'
            if i + 1 == self._handler.n_sites or i + 1 == self._handler.max_sites or i + 1 == self.max_sites:
                break

    # reset the index so that the configuration keys become regular columns again
    self.configs.reset_index(inplace=True, names=index_cols)

    # create column for number of outgroups
    self.configs['n_outgroups'] = None

    if len(self.configs) > 0:
        # determine number of outgroups, i.e. the number of outgroup bases other than -1
        self.configs['n_outgroups'] = np.sum(np.array(self.configs.outgroup_bases.to_list()) != -1, axis=1)

        # convert outgroup bases to tuples of native integers
        self._convert_outgroup_bases_to_native_types()

    # total number of sites considered (0 if the reader yielded no variants)
    self.n_sites = i + 1
[docs]
def infer(self):
    """
    Infer the ancestral allele probabilities for the data provided. This method is only supposed to be called
    manually if the data is provided directly, e.g. using :meth:`from_data`, :meth:`from_dataframe` or
    :meth:`from_est_sfs`. If the data is provided using a VCF file, this method is called automatically.

    The rate parameters are optimized with L-BFGS-B, once for every initial value (``n_runs`` in total),
    and the run with the highest likelihood is retained. Afterwards the branch probability cache and the
    per-configuration probabilities are renewed using the MLE parameters.

    :raises RuntimeError: If no site configurations are available or if the best optimization run did
        not terminate successfully.
    """
    from scipy.optimize import minimize, OptimizeResult

    # get the bounds
    bounds = self.model.get_bounds(self.n_outgroups)

    # get the likelihood function
    # this will raise an error if no data is available
    fun = self._get_likelihood()

    # log warning if unusually low number of monomorphic sites
    # (configurations without a minor base, i.e. ``minor_base == -1``, make up less than 95% of all sites)
    if self.configs[self.configs.minor_base == -1].multiplicity.sum() / self.n_sites < 0.95:
        self._logger.warning("The number of monomorphic sites is unusually low. Please note that "
                             "including monomorphic sites is necessary to obtain realistic "
                             "branch rate estimates.")

        # only point to ``n_target_sites`` if it is not in use already
        if self.n_target_sites is None:
            self._logger.warning("If your dataset does not contain any monomorphic sites, consider "
                                 "using the `n_target_sites` argument.")

    def optimize_rates(x0: Dict[str, float]) -> OptimizeResult:
        """
        Optimize the likelihood function for a single run.

        :param x0: The initial values, keyed by parameter name.
        :return: The optimization results.
        """
        # optimize using scipy
        # the initial values and bounds are passed positionally, in dictionary order
        return minimize(
            fun=fun,
            x0=np.array(list(x0.values())),
            bounds=list(bounds.values()),
            method="L-BFGS-B"
        )

    # run the optimization in parallel, one run per set of initial values
    results = parallelize_func(
        func=optimize_rates,
        data=[self.model.get_x0(bounds, self.rng) for _ in range(self.n_runs)],
        parallelize=self.parallelize,
        pbar=True,
        desc=f"{self.__class__.__name__}>Optimizing rates",
        dtype=object
    )

    # get the likelihoods for each run
    # the objective is the negative log likelihood, hence the sign flip
    self.likelihoods = -np.array([result.fun for result in results])

    # get the best likelihood
    self.likelihood = np.max(self.likelihoods)

    # get the MLE parameters for each run
    self.params_mle_runs = pd.DataFrame([result.x for result in results], columns=self.param_names)

    # get the best result
    self.result: OptimizeResult = cast(OptimizeResult, results[np.argmax(self.likelihoods)])

    # check if the optimization was successful
    if not self.result.success:
        raise RuntimeError(f"Optimization failed with message: {self.result.message}")

    # get dictionary of MLE parameters
    self.params_mle = dict(zip(self.param_names, self.result.x))

    # check if the MLE parameters are near the bounds
    near_lower, near_upper = check_bounds(
        params=self.params_mle,
        bounds=bounds,
        scale='log',
        fixed_params=self.model.fixed_params
    )

    # warn if the MLE parameters are near the bounds
    if len(near_lower | near_upper) > 0:
        self._logger.warning(f'The MLE estimate for the rates is near the upper bound for '
                             f'{near_upper} and lower bound for {near_lower}. (The tuples denote '
                             f'(lower, value, upper) for every parameter.)')

    # check if the outgroup divergence is monotonically increasing
    if not self.is_monotonic():
        self._logger.warning("The outgroup rates are not monotonically increasing. This might indicate "
                             "that the outgroups were not specified in the order of increasing divergence. "
                             f"rates: {dict(zip(self.outgroups, self.get_outgroup_divergence()))}")

    # cache the branch probabilities for the MLE parameters
    self._renew_cache()

    # renew site configuration cache
    self._update_configs()
def _update_configs(self):
    """
    Recompute the cached per-configuration probabilities from the current MLE parameters.

    Updates ``p_minor``, ``p_major`` and ``p_major_ancestral`` on ``self.configs``.
    """

    def config_probs(focal: BaseType) -> np.ndarray:
        """
        Probability of every site configuration for the given focal base type.
        """
        return self.get_p_configs(
            configs=self.configs,
            model=self.model,
            base_type=focal,
            params=self.params_mle
        )

    # probability of each configuration with the minor allele as focal base
    self.configs.p_minor = config_probs(BaseType.MINOR)

    # probability of each configuration with the major allele as focal base
    self.configs.p_major = config_probs(BaseType.MAJOR)

    # probability that the major rather than the minor allele is ancestral
    minor_probs = self.configs['p_minor'].values
    major_probs = self.configs['p_major'].values
    major_counts = self.configs['n_major'].values

    self.configs.p_major_ancestral = self._calculate_p_major_ancestral(
        p_minor=minor_probs,
        p_major=major_probs,
        n_major=major_counts
    )
[docs]
def set_mle_params(self, params: Dict[str, float]):
    """
    Set the MLE parameters and update the cache and site configurations. Use this method if you want to
    use different parameters for the annotation.

    :param params: The new parameters, keyed by parameter name.
    """
    # set the parameters
    # this has to happen first as both calls below read ``self.params_mle``
    self.params_mle = params

    # renew cache of branch probabilities
    self._renew_cache()

    # renew site configuration cache
    self._update_configs()
[docs]
def is_monotonic(self) -> bool:
    """
    Check whether the inferred outgroup divergence never decreases from one outgroup to the next.

    :return: ``True`` if every outgroup is at least as divergent as its predecessor.
    """
    divergence = self.get_outgroup_divergence()

    # compare every outgroup with its successor
    return all(closer <= farther for closer, farther in zip(divergence[:-1], divergence[1:]))
@cached_property
def p_polarization(self) -> np.ndarray | None:
    """
    The polarization probabilities obtained from the prior, or ``None`` if the prior is not a
    :class:`PolarizationPrior`. The value is computed once and cached.
    """
    # nothing to compute without a polarization prior
    if not isinstance(self.prior, PolarizationPrior):
        return None

    return self.prior._get_prior(
        configs=self.configs,
        n_ingroups=self.n_ingroups
    )
[docs]
@staticmethod
def get_p_tree(
base: int,
n_outgroups: int,
internal_nodes: List[int] | np.ndarray,
outgroup_bases: List[int] | np.ndarray,
params: Dict[str, float],
model: SubstitutionModel
) -> float:
"""
Get the probability of a tree.
:param base: An observed ingroup base index.
:param n_outgroups: The number of outgroups.
:param internal_nodes: The internal nodes of the tree. We have ``n_outgroups - 1`` internal nodes.
:param outgroup_bases: The observed base indices for the outgroups.
:param params: The parameters of the model.
:param model: The model to use. Either 'K2' or 'JC'.
"""
if n_outgroups < 1:
return 0.0
# get the number of branches
n_branches = 2 * n_outgroups - 1
# the probability for each branch
p_branches = np.zeros(n_branches, dtype=float)
# iterate over the branches
for i in range(n_branches):
# if we are on the first branch
if i == 0:
# combine ingroup base either with only outgroup or with first internal node
b1 = base
b2 = outgroup_bases[0] if n_outgroups == 1 else internal_nodes[0]
# if we are on intermediate branches
elif i < n_branches - 1:
# every internal node that is not the last one combines either
# with the next internal node or with an outgroup
i_internal = (i - 1) // 2
# get internal base
b1 = internal_nodes[i_internal]
# either connect to outgroup or next internal node
b2 = outgroup_bases[i_internal] if i % 2 == 1 else internal_nodes[i_internal + 1]
else:
# last branch connects to last internal node and last outgroup
b1 = internal_nodes[-1]
b2 = outgroup_bases[-1]
# get the probability of the branch
p_branches[i] = model._get_cached_prob(b1, b2, i, params)
# take product of all branch probabilities
prod = p_branches.prod()
return prod
[docs]
@classmethod
def get_p_config(
        cls,
        config: SiteConfig,
        base_type: BaseType,
        params: Dict[str, float],
        model: SubstitutionModel = K2SubstitutionModel(),
        internal: np.ndarray | None = None
) -> float:
    """
    Get the probability for a site configuration by summing the tree probabilities over all
    unobserved nodes.

    :param config: The site configuration.
    :param base_type: The base type, i.e. whether the major or the minor base is the focal base.
    :param params: The parameters for the substitution model.
    :param model: The substitution model to use.
    :param internal: Base indices of internal nodes of the tree if fixed. If ``None``, the internal nodes
        are considered as free parameters. -1 also indicates a free parameter. The number of internal nodes
        is the number of outgroups minus one.
    :return: The probability for a site, ``0.0`` if the focal base is missing.
    """
    n_outgroups = len(config.outgroup_bases)

    # pick the focal base
    focal = config.major_base if base_type == BaseType.MAJOR else config.minor_base

    # a missing focal base has zero probability
    if focal == -1:
        return 0.0

    # a free node may take any of the four bases
    any_base = [0, 1, 2, 3]

    # candidate bases for the internal nodes: fixed where specified, free otherwise
    internal_choices = [
        [internal[k]] if internal is not None and internal[k] != -1 else any_base
        for k in range(n_outgroups - 1)
    ]

    # candidate bases for the outgroups: observed where available, free otherwise
    outgroup_choices = [
        [config.outgroup_bases[k]] if config.outgroup_bases[k] != -1 else any_base
        for k in range(n_outgroups)
    ]

    # every free node contributes a factor of four to the number of trees
    n_free = sum(choices is any_base for choices in internal_choices + outgroup_choices)

    # probability of every possible tree
    p_trees = np.zeros(4 ** n_free, dtype=float)

    for i_tree, nodes in enumerate(itertools.product(*(internal_choices + outgroup_choices))):
        p_trees[i_tree] = cls.get_p_tree(
            base=focal,
            n_outgroups=n_outgroups,
            internal_nodes=np.array(nodes[:n_outgroups - 1]),
            outgroup_bases=np.array(nodes[n_outgroups - 1:]),
            params=params,
            model=model
        )

    # marginalize over all trees
    return p_trees.sum()
[docs]
@classmethod
def get_p_configs(
        cls,
        configs: pd.DataFrame,
        model: SubstitutionModel,
        base_type: BaseType,
        params: Dict[str, float]
) -> np.ndarray:
    """
    Get the probability of every site configuration.

    :param configs: The site configurations, one per row.
    :param model: The substitution model.
    :param base_type: The base type.
    :param params: A dictionary of the rate parameters.
    :return: The probability for each site configuration, in row order.
    """
    # evaluate the rows one by one, treating each row tuple as a site configuration
    return np.array(
        [
            cls.get_p_config(
                config=cast(SiteConfig, row),
                base_type=base_type,
                params=params,
                model=model
            )
            for row in configs.itertuples()
        ],
        dtype=float
    )
[docs]
def evaluate_likelihood(self, params: Dict[str, float]) -> float:
    """
    Evaluate the log likelihood for the given rate parameters.

    :param params: A dictionary of parameters, keyed by parameter name.
    :return: The log likelihood.
    """
    # cache the branch probabilities for the given parameters
    self._renew_cache(params)

    # the likelihood function expects the values in the order of the parameter names
    neg_log_likelihood = self._get_likelihood()
    ordered_values = [params[name] for name in self.param_names]
    value = neg_log_likelihood(ordered_values)

    # restore the cached branch probabilities of the MLE parameters if there are any
    if self.params_mle is not None:
        self._renew_cache()

    # flip the sign as the likelihood function is the negative log likelihood
    return -value
def _renew_cache(self, params: Dict[str, float] = None):
    """
    Refresh the cached branch probabilities of the substitution model.

    :param params: The model parameters to use for caching. If ``None``, the MLE parameters are used.
    """
    # fall back to the MLE parameters
    if params is None:
        params = self.params_mle

    # one cache entry per branch of the tree
    n_branches = 2 * self.n_outgroups - 1

    self.model.cache(params, n_branches)
def _get_mle_configs(self) -> pd.DataFrame:
    """
    Get the site configurations used for the MLE, i.e. only those that have the
    full number of outgroups.

    :return: The matching rows of the site configurations.
    """
    has_all_outgroups = self.configs.n_outgroups == self.n_outgroups

    return self.configs[has_all_outgroups]
def _get_likelihood(self) -> Callable[[List[float]], float]:
    """
    Get the likelihood function for the rate parameters.

    The returned function closes over a private copy of the site configurations that have the full
    number of outgroups, so ``self.configs`` is never modified.

    :return: Function mapping a list of rate parameters (in the order of ``self.param_names``) to the
        negative log likelihood.
    :raises RuntimeError: If no site configurations are available.
    """
    if self.configs is None:
        raise RuntimeError("No sites available. Note that you can't call infer() yourself "
                           "when using this class with Parser or Annotator.")

    # Only consider sites with the correct number of outgroups.
    # We take an explicit copy as we modify the data below: writing into the filtered
    # frame directly triggers pandas' SettingWithCopyWarning and leaves it unclear
    # whether the original site configurations are affected.
    configs = self._get_mle_configs().copy()

    # Set the minor base to -1 if the major allele is fixed.
    # We don't want to consider minor allele not present in the subsample
    # when optimizing the branch rates.
    configs.loc[configs.n_major == self.n_ingroups, 'minor_base'] = -1

    # make variables available in the inner function
    model = self.model
    param_names = self.param_names
    n_outgroups = self.n_outgroups

    def compute_likelihood(params: List[float]) -> float:
        """
        Compute the negative log likelihood of the parameters.

        :param params: A list of rate parameters.
        :return: The negative log likelihood.
        """
        # unpack the parameters
        rates = dict(zip(param_names, params))

        # cache the branch probabilities
        model.cache(rates, 2 * n_outgroups - 1)

        # the likelihood for each site: first column major allele, second column minor allele
        p_sites = np.zeros(shape=(configs.shape[0], 2), dtype=float)

        # get the probability for each site and major allele
        p_sites[:, 0] = MaximumLikelihoodAncestralAnnotation.get_p_configs(
            configs=configs,
            model=model,
            base_type=BaseType.MAJOR,
            params=rates
        )

        # get the probability for each site and minor allele
        p_sites[:, 1] = MaximumLikelihoodAncestralAnnotation.get_p_configs(
            configs=configs,
            model=model,
            base_type=BaseType.MINOR,
            params=rates
        )

        # Return the negative log likelihood and take average over major and minor bases
        # Also multiply by the multiplicity of each site.
        # The final likelihood is the product of the likelihoods for each site.
        return -(np.log(p_sites.mean(axis=1)) * configs.multiplicity.values).sum()

    return compute_likelihood
def _get_site_indices(self) -> np.ndarray:
    """
    Map every parsed site to the index of its site configuration.

    :return: Array with one config index per site, -1 for sites that are not included.
    """
    # start out with all sites marked as not included
    config_of_site = np.full(self.n_sites, -1, dtype=int)

    # every configuration records the sites it was observed at
    for row in self.configs.itertuples():
        for site in row.sites:
            config_of_site[site] = row.Index

    return config_of_site
def _get_ancestral_from_prob(
        self,
        p_major_ancestral: np.ndarray | float,
        major_base: np.ndarray | int,
        minor_base: np.ndarray | int
) -> np.ndarray | int:
    """
    Get the ancestral allele from the probability of the major allele being ancestral.

    The major base is chosen for probabilities of at least 0.5 and the minor base for probabilities
    below 0.5. Entries for which neither holds (i.e. NaN) are set to -1.

    :param p_major_ancestral: The probability or probabilities of the major allele being ancestral.
        Any numeric scalar (not only ``float``) or an array is accepted.
    :param major_base: The major base index or indices.
    :param minor_base: The minor base index or indices.
    :return: Array of ancestral base indices, or a single base index if a scalar was passed.
    """
    # Remember whether we were given a scalar. Checking the dimensionality rather than
    # ``isinstance(..., float)`` also covers integers and non-float64 numpy scalars.
    is_scalar = np.ndim(p_major_ancestral) == 0

    # work on one-dimensional arrays throughout
    probs = np.atleast_1d(np.asarray(p_major_ancestral, dtype=float))
    major = np.atleast_1d(major_base)
    minor = np.atleast_1d(minor_base)

    # compute the masks once, NaN entries are in neither of them
    is_major = probs >= 0.5
    is_minor = probs < 0.5

    # initialize with -1 which denotes an unknown ancestral base
    ancestral_bases = np.full(probs.shape, -1, dtype=np.int8)
    ancestral_bases[is_major] = major[is_major]
    ancestral_bases[is_minor] = minor[is_minor]

    return ancestral_bases[0] if is_scalar else ancestral_bases
def _get_internal_prob(
        self,
        site: SiteConfig,
        internal: np.ndarray | None = None
) -> float:
    """
    Get the summed probability of a site configuration over both focal bases (minor and major)
    under the MLE parameters, optionally with fixed internal nodes.

    :param site: The site configuration.
    :param internal: Base indices of internal nodes of the tree if fixed. If ``None``, the internal nodes
        are considered as free parameters. -1 also indicates a free parameter.
    :return: The probability with the minor base as focal base plus the probability with the major
        base as focal base.
    """

    def focal_prob(focal: BaseType) -> float:
        """
        Probability of the site for the given focal base type.
        """
        return self.get_p_config(
            config=site,
            base_type=focal,
            params=self.params_mle,
            model=self.model,
            internal=internal
        )

    # probability with the minor allele as focal base
    p_minor = focal_prob(BaseType.MINOR)

    # probability with the major allele as focal base
    p_major = focal_prob(BaseType.MAJOR)

    return p_minor + p_major
def _get_internal_probs(
        self,
        site: SiteConfig,
        i_internal: int,
) -> np.ndarray:
    """
    Get the probability of the site for each of the four bases at the specified internal node.

    :param site: The site configuration.
    :param i_internal: The index of the internal node.
    :return: Array of four probabilities, one per base at the internal node. If there are fewer than
        two outgroups, all four entries hold the unconditioned probability of the site.
    """
    n_outgroups = len(site.outgroup_bases)

    # there are no internal nodes if we have fewer than two outgroups
    if n_outgroups < 2:
        return np.full(4, self._get_internal_prob(site))

    # all internal nodes are free (-1) except for the one we fix below
    fixed_nodes = np.full(n_outgroups, fill_value=-1, dtype=int)

    def prob_given(base: int) -> float:
        """
        Probability of the site with the internal node fixed to the given base.
        """
        fixed_nodes[i_internal] = base

        return self._get_internal_prob(site, internal=fixed_nodes)

    return np.array([prob_given(base) for base in range(4)], dtype=float)
[docs]
def get_inferred_site_info(self) -> Generator[SiteInfo, None, None]:
    """
    Get the site information for the sites included in the parsing process. The sites are in the same order as
    parsed. You can use :meth:`get_site_info` to get the site information for a specific site.

    :return: A generator yielding the site information (see :meth:`get_site_info`).
    :raises RuntimeError: If the subsample mode is ``probabilistic``.
    """
    # not available with probabilistic subsampling
    if self.subsample_mode == 'probabilistic':
        raise RuntimeError("get_inferred_site_info() not implemented with probabilistic subsampling.")

    # config index for every parsed site, dropping sites that were not included (-1)
    config_indices = self._get_site_indices()
    included = config_indices[config_indices != -1]

    # one row per included site, in parsing order
    for row in self.configs.iloc[included].itertuples():
        yield self.get_site_info(
            n_major=row.n_major,
            major_base=row.major_base,
            minor_base=row.minor_base,
            outgroup_bases=row.outgroup_bases,
            pass_indices=True
        )
def _get_site_info(self, configs: List[SiteConfig]) -> SiteInfo:
    """
    Get information on the specified site using the inferred parameters.

    The configuration with the highest multiplicity serves as reference: its bases are reported and its
    probabilities are evaluated. The probability of the major allele being ancestral is the
    multiplicity-weighted average over all configurations.

    :param configs: The site configurations with differing numbers of major alleles with their multiplicities
        summing up to 1.
    :return: The site information.
    :raises RuntimeError: If no maximum likelihood parameters are available.
    """
    if self.params_mle is None:
        raise RuntimeError("No maximum likelihood parameters available.")

    # use most likely configuration as reference
    i_max = np.argmax([c.multiplicity for c in configs])
    ref = configs[i_max]

    # get the probability for the minor allele
    p_minor = self.get_p_config(
        config=ref,  # use the reference config as representative
        base_type=BaseType.MINOR,
        params=self.params_mle,
        model=self.model
    )

    # get the probability for the major allele
    p_major = self.get_p_config(
        config=ref,  # use the reference config as representative
        base_type=BaseType.MAJOR,
        params=self.params_mle,
        model=self.model
    )

    # get the probability that the major allele is ancestral rather than the minor allele
    # for configs whose major and minor bases are swapped with respect to the reference,
    # the two probabilities are swapped as well
    p_major_ancestral_probs = self._calculate_p_major_ancestral(
        p_minor=np.array([p_minor if c.minor_base == ref.minor_base else p_major for c in configs]),
        p_major=np.array([p_major if c.major_base == ref.major_base else p_minor for c in configs]),
        n_major=np.array([c.n_major for c in configs])
    )

    # configs for which the minor allele turned out to be the major allele in the subsample
    # here we flip the probability so that it refers to the major allele of the reference
    alt_config = np.array([c.minor_base != ref.minor_base for c in configs])
    p_major_ancestral_probs[alt_config] = 1 - p_major_ancestral_probs[alt_config]

    # take the weighted average
    weights = np.array([c.multiplicity for c in configs])
    p_major_ancestral = (p_major_ancestral_probs * weights).sum()

    # get the ancestral alleles using p_major_ancestral
    major_ancestral = self.get_base_string(self._get_ancestral_from_prob(
        p_major_ancestral=p_major_ancestral,
        major_base=ref.major_base,
        minor_base=ref.minor_base
    ))

    # ancestral base probabilities for the first node
    p_bases_first_node = self._get_internal_probs(site=ref, i_internal=0)

    # get the most likely base for the first node and its normalized probability
    total = p_bases_first_node.sum()
    i_max = np.argmax(p_bases_first_node)
    p_first_node_ancestral = p_bases_first_node[i_max] / total if total > 0 else 0
    first_node_ancestral = self.get_base_string(i_max)

    return SiteInfo(
        n_major={config.n_major: config.multiplicity for config in configs},
        major_base=self.get_base_string(ref.major_base),
        minor_base=self.get_base_string(ref.minor_base),
        outgroup_bases=list(self.get_base_string(np.array(ref.outgroup_bases))),
        p_minor=p_minor,
        p_major=p_major,
        p_major_ancestral=p_major_ancestral,
        major_ancestral=major_ancestral,
        p_bases_first_node=dict(zip(bases, p_bases_first_node)),
        p_first_node_ancestral=p_first_node_ancestral,
        first_node_ancestral=first_node_ancestral,
        rate_params=self.params_mle
    )
[docs]
def get_site_info(
        self,
        n_major: int,
        major_base: int | str,
        minor_base: int | str,
        outgroup_bases: List[int | str] | np.ndarray,
        pass_indices: bool = False
) -> SiteInfo:
    """
    Get information on the specified site using the inferred parameters.

    :param n_major: The number of copies of the major allele.
    :param major_base: The major base index or string.
    :param minor_base: The minor base index or string.
    :param outgroup_bases: The outgroup base indices or strings.
    :param pass_indices: If ``True``, the bases are taken to be base indices already and are passed on
        unchanged. Otherwise, they are converted to base indices first.
    :return: The site information.
    """
    if pass_indices:
        # the bases are indices already
        i_major, i_minor, i_outgroups = major_base, minor_base, outgroup_bases
    else:
        # convert the bases to indices
        i_major = self.get_base_index(major_base)
        i_minor = self.get_base_index(minor_base)
        i_outgroups = self.get_base_index(np.array(outgroup_bases))

    # a single site configuration describes the site
    return self._get_site_info([
        SiteConfig(
            n_major=n_major,
            major_base=i_major,
            minor_base=i_minor,
            outgroup_bases=i_outgroups
        )
    ])
def _calculate_p_major_ancestral(
        self,
        p_minor: float | np.ndarray,
        p_major: float | np.ndarray,
        n_major: int | np.ndarray
) -> float | np.ndarray:
    """
    Calculate the probability that the major allele rather than the minor allele is ancestral.

    :param p_minor: The probability or probabilities of the minor allele.
    :param p_major: The probability or probabilities of the major allele.
    :param n_major: The number or numbers of major alleles.
    :return: The probability or probabilities that the ancestral allele is the major allele. ``nan``
        if a ``ZeroDivisionError`` is raised, which only happens for scalars.
    """
    # nothing to do for empty arrays
    if isinstance(p_minor, np.ndarray) and len(p_minor) == 0:
        return np.array([])

    try:
        # without a prior, the major and minor alleles are weighted equally
        if self.prior is None:
            return p_major / (p_major + p_minor)

        # polarization prior for the major allele
        pi = self.p_polarization[self.n_ingroups - n_major]

        # weight the major and minor alleles by the prior
        weighted_major = pi * p_major

        return weighted_major / (weighted_major + (1 - pi) * p_minor)

    # only occurs when we deal with scalars
    except ZeroDivisionError:
        return np.nan
@staticmethod
def _is_confident(threshold: float, p: float) -> bool:
"""
Whether we are confident enough about the ancestral allele state.
:param threshold: Confidence threshold.
:param p: Probability of the major allele being ancestral as opposed to the minor allele.
:return: Whether we are confident enough.
"""
return not (1 - threshold) / 2 < p < 1 - (1 - threshold) / 2
[docs]
def annotate_site(self, variant: Union['cyvcf2.Variant', DummyVariant]):
    """
    Annotate a single site. The variant is modified in place: the ancestral allele, its probability and
    an info string are written to the variant's INFO field.

    Sites that are not SNPs are annotated using maximum parsimony. SNPs are annotated using the
    inferred parameters, provided the ancestral allele is known and the confidence is high enough.

    :param variant: The variant to annotate.
    """
    # set default values, ``.`` denotes a missing value
    ancestral_base = '.'
    ancestral_prob = '.'

    # use maximum parsimony if we don't have an SNP
    if isinstance(variant, DummyVariant) or not variant.is_snp:
        ancestral_base = MaximumParsimonyAncestralAnnotation._get_ancestral(variant, self._ingroup_mask)

        ancestral_info = 'monomorphic'

        # increase the number of annotated sites
        self.n_annotated += 1
    else:

        try:
            configs = self._parse_variant(variant)

        except _PolyAllelicSiteError:
            ancestral_info = 'polyallelic'

        except _TooFewIngroupsSiteError:
            ancestral_info = 'too few ingroups'

        else:
            site = self._get_site_info(configs)

            # only proceed if the ancestral allele is known
            if site.major_ancestral in bases:

                # get site information dictionary
                site_dict = site.__dict__

                # update info
                ancestral_info = str(site_dict)

                # only proceed with annotation if the confidence is high enough
                if self._is_confident(self.confidence_threshold, site.p_major_ancestral):

                    # we take most likely configuration as reference
                    ref = configs[np.argmax([c.multiplicity for c in configs])]

                    # obtain ad hoc annotation for sanity checking
                    site_info_ad_hoc = AdHocAncestralAnnotation._get_site_info(ref)

                    # log debug message if ad hoc and maximum likelihood annotation disagree
                    if site_info_ad_hoc['ancestral_base'] != site.major_ancestral:
                        self._logger.debug(
                            "Mismatch with ad hoc ancestral allele annotation: " +
                            str(dict(
                                site=f"{variant.CHROM}:{variant.POS}",
                                ancestral_base_ad_hoc=site_info_ad_hoc['ancestral_base'],
                            ) | site_dict)
                        )

                        # append site to mismatches
                        self.mismatches.append(site)

                    # update ancestral base
                    ancestral_base = site.major_ancestral

                    # update ancestral probability
                    # ``p_major_ancestral`` refers to the major allele, so we take the
                    # complement if the minor allele was chosen as ancestral
                    if site.major_base == site.major_ancestral:
                        ancestral_prob = site.p_major_ancestral
                    else:
                        ancestral_prob = 1 - site.p_major_ancestral

                    # increase the number of annotated sites
                    self.n_annotated += 1
                else:
                    ancestral_info = 'below confidence threshold'
            else:
                ancestral_info = 'invalid or unknown ancestral allele'

    # set the ancestral allele
    variant.INFO[self._handler.info_ancestral] = ancestral_base

    # set the ancestral allele probability
    variant.INFO[self._handler.info_ancestral + '_prob'] = ancestral_prob

    # set info field
    variant.INFO[self._handler.info_ancestral + "_info"] = ancestral_info
[docs]
def plot_likelihoods(
        self,
        file: str = None,
        show: bool = True,
        title: str = 'rate likelihoods',
        scale: Literal['lin', 'log'] = 'lin',
        ax: 'plt.Axes' = None,
        ylabel: str = 'lnl'
) -> 'plt.Axes':
    """
    Visualize the likelihoods of the rate optimization runs using a scatter plot.

    :param scale: y-scale of the plot.
    :param title: Plot title.
    :param file: File to save plot to.
    :param show: Whether to show plot.
    :param ax: Axes to plot on. Only for Python visualization backend.
    :param ylabel: Label for y-axis.
    :return: Axes object
    """
    # import locally so that the visualization module is only loaded when plotting
    from .visualization import Visualization

    # plot the likelihood of every optimization run
    return Visualization.plot_scatter(
        values=self.likelihoods,
        file=file,
        show=show,
        title=title,
        scale=scale,
        ax=ax,
        ylabel=ylabel,
    )
[docs]
def get_folded_spectra(
        self,
        groups: List[Literal['major_base', 'minor_base', 'outgroup_bases']] = ['major_base'],
) -> Spectra:
    """
    Get the folded spectra for the parsed sites (used to estimate the parameters).

    :param groups: The groups to group the spectra by. If empty, a single spectrum named ``all``
        is returned. The default list is only read, never modified.
    :return: Spectra object
    """
    # only sites with the full number of outgroups
    configs = self._get_mle_configs()

    # group by n_major and groups
    grouped = configs.groupby(['n_major'] + groups).sum()

    if len(groups) == 0:
        # plain index over all possible values for n_major
        index = np.arange(self.n_ingroups + 1)
    else:
        # new index to include all possible values for n_major
        index = pd.MultiIndex.from_product(
            [np.arange(self.n_ingroups + 1).tolist()] + grouped.index.levels[1:],
            names=['n_major'] + groups
        )

    # reindex, filling combinations that were not observed with zeros
    grouped = grouped.reindex(index, fill_value=0)

    # if we only group by n_major
    if len(groups) == 0:
        # the multiplicities are ordered by n_major and are reversed here
        return Spectra.from_dict(dict(all=grouped.multiplicity[::-1].tolist()))

    # iterate over groups
    spectra = {}
    for i, group in grouped.groupby(level=groups):

        # the group key is a scalar for a single group and a tuple otherwise
        if not isinstance(i, tuple):
            name = f"{groups[0]}={self.get_base_string(i)}"
        else:
            name = ", ".join([f"{a}={self.get_base_string(b)}" for a, b in zip(groups, i)])

        # the multiplicities are ordered by n_major and are reversed here
        spectra[name] = group.multiplicity[::-1].tolist()

    return Spectra.from_dict(spectra)
@staticmethod
def _get_branch(params: Dict[str, float], i: int) -> float:
"""
Get the branch rate for the given index.
:param params: The parameters.
:param i: The index.
:return: The branch rate.
"""
return params['K'] if 'K' in params else params[f'K{i}']
[docs]
def get_outgroup_divergence(self) -> np.ndarray:
    """
    Get the inferred rates between the ingroup and each outgroup by summing the rates of the
    branches that connect them.

    :return: One rate for each outgroup.
    :raises RuntimeError: If no maximum likelihood parameters are available.
    """
    if self.params_mle is None:
        raise RuntimeError("No maximum likelihood parameters available.")

    i_last = self.n_outgroups - 1
    rates = np.zeros(self.n_outgroups, dtype=float)

    for i in range(self.n_outgroups):

        if i < i_last:
            # even branches up to the outgroup's internal node plus the odd branch to the outgroup
            branches = [2 * j for j in range(i + 1)] + [2 * i + 1]
        else:
            # the last outgroup hangs off the last internal node via the last (even) branch
            branches = [2 * j for j in range(i)] + [2 * i]

        rates[i] = np.sum([self._get_branch(self.params_mle, b) for b in branches])

    return rates
[docs]
def get_observed_transition_transversion_ratio(self) -> float:
    """
    Get the observed transition/transversion ratio. Note that this may differ from the estimated ratio for
    :class:`K2SubstitutionModel`.

    :return: The observed transition/transversion ratio.
    :raises RuntimeError: If there are no sites to calculate the ratio from.
    """
    configs = self._get_mle_configs()

    if len(configs) == 0:
        raise RuntimeError("No sites available to calculate the transition/transversion ratio.")

    # (minor, major) base pair of every site configuration
    base_pairs = configs[["minor_base", "major_base"]].apply(tuple, axis=1)

    def n_sites_with(substitutions) -> float:
        """
        Total multiplicity of the configurations whose base pair is among the given substitutions.
        """
        is_match = base_pairs.isin([tuple(e) for e in substitutions])

        return configs[is_match].multiplicity.sum()

    n_transitions = n_sites_with(SubstitutionModel._transitions)
    n_transversions = n_sites_with(SubstitutionModel._transversions)

    return n_transitions / n_transversions
[docs]
class AdHocAncestralAnnotation(_OutgroupAncestralAlleleAnnotation):
    """
    Ancestral allele annotation based on a simple scoring rule. Used for testing and sanity checking.
    """

    @staticmethod
    def _get_site_info(config: SiteConfig) -> dict:
        """
        Determine the ancestral base of a site configuration by scoring the observed bases.

        The major base scores 1.2, the minor base and every outgroup base score 1. The base with the
        highest total score is taken to be ancestral.

        :param config: The site configuration.
        :return: Dictionary with the key 'ancestral_base', denoting the ancestral base string, ``.``
            if no base was observed.
        """
        # all observed bases: major, minor and outgroups
        # noinspection PyTypeChecker
        observed = np.concatenate(([config.major_base], [config.minor_base], config.outgroup_bases))

        # the score of each observed base, slightly favoring the major base
        # noinspection PyTypeChecker
        weights = np.concatenate(([1.2], [1], [1] * len(config.outgroup_bases)))

        # missing bases are denoted by -1
        is_present = observed != -1

        # nothing to go on if all bases are missing
        if not is_present.any():
            return dict(ancestral_base='.')

        present_bases = observed[is_present]
        present_weights = weights[is_present]

        # total score for each of the four bases
        totals = np.array([present_weights[present_bases == b].sum() for b in range(4)])

        # the highest scoring base is taken to be ancestral
        return dict(ancestral_base=bases[totals.argmax()])

    def _get_ancestral_base_and_info(self, variant: Union['cyvcf2.Variant', DummyVariant]) -> Tuple[str, str]:
        """
        Determine the ancestral base and the info string for a variant, and count the site
        as annotated if an ancestral base could be assigned.

        :param variant: The variant to annotate.
        :return: The ancestral base (``.`` if unknown) and the info string.
        """
        # use maximum parsimony if we have a mono-allelic site
        if isinstance(variant, DummyVariant) or not variant.is_snp:
            base = MaximumParsimonyAncestralAnnotation._get_ancestral(variant, self._ingroup_mask)
            self.n_annotated += 1

            return base, 'monomorphic'

        try:
            # parse the site
            configs = self._parse_variant(variant)
        except _PolyAllelicSiteError:
            return '.', 'polyallelic'
        except _TooFewIngroupsSiteError:
            return '.', 'too few ingroups'

        # use config with the highest multiplicity
        ref = configs[np.argmax([c.multiplicity for c in configs])]

        # get site information dictionary
        site = self._get_site_info(ref)

        # give up if the ancestral allele is unknown
        if site['ancestral_base'] not in bases:
            return '.', 'invalid or unknown ancestral allele'

        self.n_annotated += 1

        return site['ancestral_base'], str(site)

    def annotate_site(self, variant: Union['cyvcf2.Variant', DummyVariant]):
        """
        Annotate a single site in place. Mono-allelic sites are annotated using maximum parsimony. Sites
        with more than two alleles are ignored.

        :param variant: The variant to annotate.
        """
        ancestral_base, ancestral_info = self._get_ancestral_base_and_info(variant)

        # set the ancestral allele
        variant.INFO[self._handler.info_ancestral] = ancestral_base

        # set info field
        variant.INFO[self._handler.info_ancestral + "_info"] = ancestral_info
class _ESTSFSAncestralAnnotation(AncestralAlleleAnnotation): # pragma: no cover
"""
A wrapper around EST-SFS. Used for testing.
"""
def __init__(
        self,
        anc: MaximumLikelihoodAncestralAnnotation
):
    """
    Create a new ESTSFSAncestralAnnotation instance.

    :param anc: The maximum likelihood ancestral annotation whose sites and settings
        (seed, number of outgroups, substitution model and number of runs) are passed on to EST-SFS.
    """
    super().__init__()

    #: The ancestral annotation.
    self.anc = anc

    #: The likelihoods for each run.
    self.likelihoods: np.ndarray | None = None

    #: The minimum likelihood.
    self.likelihood: float | None = None

    #: The MLE parameters.
    self.params_mle: Dict[str, float] | None = None

    #: The probabilities for each site.
    self.probs: pd.DataFrame | None = None
def create_seed_file(self, seed_file: str):
    """
    Write the seed of the ancestral annotation to the seed file.

    :param seed_file: Path to the seed file.
    """
    with open(seed_file, 'w') as handle:
        # the file holds nothing but the seed, without trailing newline
        print(self.anc.seed, end='', file=handle)
def create_config_file(self, config_file: str):
    """
    Write the EST-SFS config file holding the number of outgroups, the substitution
    model and the number of runs.

    :param config_file: Path to the config file.
    """
    # EST-SFS model code by class name of the substitution model
    model_codes = {
        'JCSubstitutionModel': 0,
        'K2SubstitutionModel': 1,
    }

    with open(config_file, 'w') as handle:
        print(f"n_outgroup {self.anc.n_outgroups}", file=handle)
        print(f"model {model_codes[self.anc.model.__class__.__name__]}", file=handle)
        print(f"nrandom {self.anc.n_runs}", file=handle)
def infer(
        self,
        binary: str = 'EST_SFS',
        wd: str = None,
        execute: Callable = None,
):
    """
    Infer the ancestral allele using EST-SFS.

    The input files (sites, seed and config) and the output files are temporary files that are handed
    to EST-SFS by name. The site probabilities written by EST-SFS are parsed before the temporary
    files are removed.

    :param binary: The path to the EST-SFS binary.
    :param wd: The working directory. Only used by the default ``execute`` function.
    :param execute: The function to execute the bash command. It receives the command string. By
        default, the command is run through the shell, raising an error on non-zero exit status.
    """
    # define default function for executing command
    if execute is None:
        def shell(command: str):
            """
            Execute shell command.

            :param command: Command string
            """
            # NOTE: the command string is interpreted by the shell, so ``binary`` has to be trusted
            return subprocess.run(command, check=True, cwd=wd, shell=True)

        execute = shell

    # all files are removed again when leaving the context
    with tempfile.NamedTemporaryFile('w') as sites_file, \
            tempfile.NamedTemporaryFile('w') as seed_file, \
            tempfile.NamedTemporaryFile('w') as config_file, \
            tempfile.NamedTemporaryFile('w') as out_sfs, \
            tempfile.NamedTemporaryFile('w') as out_p:
        # create the sites file
        self.anc.to_est_sfs(sites_file.name)

        # create the seed file
        self.create_seed_file(seed_file.name)

        # create the config file
        self.create_config_file(config_file.name)

        # construct command string
        # argument order: config, sites, seed, output SFS, output site probabilities
        command = (f"{binary} "
                   f"{config_file.name} "
                   f"{sites_file.name} "
                   f"{seed_file.name} "
                   f"{out_sfs.name} "
                   f"{out_p.name} ")

        # log command signature
        self._logger.info(f"Running: '{command}'")

        # execute command
        execute(command)

        # parse the site probabilities while the output file still exists
        self.parse_est_sfs_output(out_p.name)
def parse_est_sfs_output(self, file: str):
"""
Parse the output of the EST-SFS program containing the site probabilities.
:param file: The file name.
:return: The data frame.
"""
# filter out lines starting with 0
filtered_lines = []
with open(file, 'r') as f:
for i, line in enumerate(f):
# strip line
line = line.strip()
if line.startswith('0'):
if i == 4:
# parse likelihoods
self.likelihoods = np.array(line.split()[2:], dtype=float)
self.likelihood = np.min(self.likelihoods)
if i == 5:
# parse MLE parameters
data = np.array(line.split()[2:])
self.params_mle = dict(zip([d.upper() for d in data[::2]], data[1::2].astype(float)))
if i == 6 and isinstance(self.anc.model, K2SubstitutionModel):
# parse kappa
self.params_mle['k'] = float(line.split()[2])
else:
filtered_lines.append(line.strip())
# read into dataframe
self.probs = pd.read_csv(StringIO('\n'.join(filtered_lines)), sep=" ", header=None)
# drop the first column
self.probs.drop(self.probs.columns[0], axis=1, inplace=True)
# rename columns
self.probs.rename(columns={1: 'config', 2: 'prob'}, inplace=True)
def annotate_site(self, variant: Union['cyvcf2.Variant', DummyVariant]):
"""
Not implemented.
:param variant: The variant to annotate.
:raises: NotImplementedError
"""
raise NotImplementedError
def to_file(self, file: str):
"""
Save object to file (without reference to AncestralAlleleAnnotation object).
:param file: File path.
"""
self.anc = None
with open(file, 'w') as fh:
fh.write(self.to_json())
def to_json(self) -> str:
"""
Serialize object.
:return: JSON string
"""
return jsonpickle.encode(self, indent=4, warn=True)
@classmethod
def from_json(cls, json: str, classes=None) -> 'Self':
"""
Unserialize object.
:param classes: Classes to be used for unserialization
:param json: JSON string
"""
return jsonpickle.decode(json, classes=classes)
@classmethod
def from_file(cls, file: str, classes=None) -> 'Self':
"""
Load object from file.
:param classes: Classes to be used for unserialization
:param file: File to load from
"""
with open(file, 'r') as fh:
return cls.from_json(fh.read(), classes)
[docs]
class Annotator(MultiHandler):
    """
    Annotate a VCF file with the given annotations.

    Example usage:

    ::

        import fastdfe as fd

        ann = fd.Annotator(
            vcf="http://ftp.1000genomes.ebi.ac.uk/vol1/ftp/data_collections/"
                "1000_genomes_project/release/20181203_biallelic_SNV/"
                "ALL.chr21.shapeit2_integrated_v1a.GRCh38.20181129.phased.vcf.gz",
            fasta="http://ftp.ensembl.org/pub/release-109/fasta/homo_sapiens/"
                  "dna/Homo_sapiens.GRCh38.dna.chromosome.21.fa.gz",
            gff="http://ftp.ensembl.org/pub/release-109/gff3/homo_sapiens/"
                "Homo_sapiens.GRCh38.109.chromosome.21.gff3.gz",
            output='sapiens.chr21.degeneracy.vcf.gz',
            annotations=[fd.DegeneracyAnnotation()],
            aliases=dict(chr21=['21'])
        )

        ann.annotate()

    """

    def __init__(
            self,
            vcf: str,
            output: str,
            annotations: List[Annotation],
            gff: str | None = None,
            fasta: str | None = None,
            info_ancestral: str = 'AA',
            max_sites: int = np.inf,
            seed: int | None = 0,
            cache: bool = True,
            aliases: Dict[str, List[str]] | None = None,
    ):
        """
        Create a new annotator instance.

        :param vcf: The path to the VCF file, can be gzipped, urls are also supported
        :param output: The path to the output file
        :param annotations: The annotations to apply.
        :param gff: The path to the GFF file, can be gzipped, urls are also supported. Required for
            annotations that require a GFF file.
        :param fasta: The path to the FASTA file, can be gzipped, urls are also supported. Required for
            annotations that require a FASTA file.
        :param info_ancestral: The tag in the INFO field that contains the ancestral allele
        :param max_sites: Maximum number of sites to consider
        :param seed: Seed for the random number generator. Use ``None`` for no seed.
        :param cache: Whether to cache files downloaded from urls
        :param aliases: Dictionary of aliases for the contigs in the VCF file, e.g. ``{'chr1': ['1']}``.
            This is used to match the contig names in the VCF file with the contig names in the FASTA file
            and GFF file. Defaults to no aliases.
        """
        super().__init__(
            vcf=vcf,
            gff=gff,
            fasta=fasta,
            info_ancestral=info_ancestral,
            max_sites=max_sites,
            seed=seed,
            cache=cache,
            # use a fresh dictionary per instance rather than a shared mutable default
            aliases={} if aliases is None else aliases
        )

        #: The path to the output file.
        self.output: str = output

        #: The annotations to apply.
        self.annotations: List[Annotation] = annotations

        #: The VCF writer.
        self._writer: Optional['cyvcf2.Writer'] = None

    def _setup(self):
        """
        Set up the annotator: set up all annotations and create the VCF writer.

        :raises ImportError: If the optional ``cyvcf2`` package is not installed.
        """
        try:
            from cyvcf2 import Writer
        except ImportError as err:
            raise ImportError(
                "VCF support in fastdfe requires the optional 'cyvcf2' package. "
                "Please install fastdfe with the 'vcf' extra: pip install fastdfe[vcf]"
            ) from err

        for annotation in self.annotations:
            annotation._setup(self)

        # create the writer
        self._writer = Writer(self.output, self._reader)

    def _teardown(self):
        """
        Tear down the annotator: tear down all annotations and close the VCF writer and reader.
        """
        for annotation in self.annotations:
            annotation._teardown()

        # close the writer and reader
        self._writer.close()
        self._reader.close()

    def annotate(self):
        """
        Annotate the VCF file.

        Every annotation is applied to each site in turn and the annotated record is written
        to the output file. If an error occurs while processing the sites, the VCF writer and
        reader are closed before the error is propagated.
        """
        self._logger.info('Start annotating')

        # set up the annotator
        self._setup()

        try:
            # get progress bar
            with self.get_pbar(desc=f"{self.__class__.__name__}>Processing sites") as pbar:

                # iterate over the sites
                for i, variant in enumerate(self._reader):

                    # apply annotations
                    for annotation in self.annotations:
                        annotation.annotate_site(variant)

                    # write the variant
                    self._writer.write_record(variant)

                    # update the progress bar
                    pbar.update()

                    # explicitly stopping after ``n`` sites fixes a bug with cyvcf2:
                    # 'error parsing variant with `htslib::bcf_read` error-code: 0 and ret: -2'
                    if i + 1 == self.n_sites or i + 1 == self.max_sites:
                        break
        except BaseException:
            # release the file handles so that a failure does not leak them
            self._writer.close()
            self._reader.close()
            raise

        # tear down the annotator
        self._teardown()